1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """GRU sequence-to-sequence decoder with additive (Bahdanau) attention.

    At every decoding step the hidden state of the last GRU layer is used as
    the attention query over the encoder outputs. The resulting context
    vector is concatenated with the embedded input token and fed to a
    stacked GRU. The GRU outputs are finally projected to vocabulary logits
    by a linear layer.

    Args:
        vocab_size: Number of rows of the token embedding and number of
            output logits.
        embed_size: Dimension of the token embeddings.
        num_hiddens: Hidden size of the GRU. It is also passed as all three
            size arguments of ``d2l.AdditiveAttention`` (presumably key,
            query and hidden sizes -- confirm against that class).
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability given to both the attention module and
            ``nn.GRU`` (which applies it between stacked layers).
        **kwargs: Forwarded unchanged to the ``AttentionDecoder`` base class.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU consumes [context ; embedded token] at every step, hence
        # the input size embed_size + num_hiddens.
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Initialised here so `attention_weights` returns an empty list
        # before the first forward pass instead of raising AttributeError.
        self._attention_weights = []

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the initial decoder state from the encoder's output.

        Args:
            enc_outputs: ``(outputs, hidden_state)`` pair returned by the
                encoder. ``outputs`` is assumed to be time-major
                ``(num_steps, batch_size, num_hiddens)``, as produced by
                ``d2l.Seq2SeqEncoder`` -- confirm for other encoders.
            enc_valid_lens: Valid source lengths, forwarded unchanged to the
                attention module on every step. May be ``None``.
            *args: Ignored; accepted for interface compatibility.

        Returns:
            Tuple ``(enc_outputs, hidden_state, enc_valid_lens)``, in which
            the encoder outputs have their first two axes swapped.
        """
        outputs, hidden_state = enc_outputs
        # Swap the time and batch axes so the encoder outputs can serve as
        # batch-major keys/values for the attention module.
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """Run the decoder over every time step of ``X``.

        Args:
            X: Integer token ids, assumed batch-major
                ``(batch_size, num_steps)``.
            state: ``(enc_outputs, hidden_state, enc_valid_lens)`` as returned
                by :meth:`init_state` or by a previous call to this method.

        Returns:
            ``(outputs, state)``.
            - ``outputs`` holds the logits with shape
              ``(batch_size, num_steps, vocab_size)``.
            - ``state`` is the *list* ``[enc_outputs, hidden_state,
              enc_valid_lens]`` carrying the updated GRU hidden state.

        Side effect: replaces ``self._attention_weights`` with one entry per
        decoding step, taken from ``self.attention.attention_weights``.
        """
        enc_outputs, hidden_state, enc_valid_lens = state
        # Embed, then make time-major: (num_steps, batch_size, embed_size),
        # so that iterating over X walks the time steps.
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # Query from the last GRU layer: (batch_size, 1, num_hiddens).
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # The context must be (batch_size, 1, num_hiddens) so that the
            # concatenation below matches the GRU input size.
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate the context and the embedded token on the
            # feature axis.
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # nn.GRU (batch_first=False) expects (1, batch_size, features).
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # Stack the steps along time, project to logits, then return to
        # batch-major layout.
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        """Attention weights recorded during the most recent forward pass.

        A list with one entry per decoding step. It is empty before the
        first forward pass.
        """
        return self._attention_weights
# Test the Bahdanau attention decoder on a minibatch of 4 input sequences,
# each with 7 time steps.
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
                                  num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
# A bare tuple is only displayed in an interactive notebook. Printing the
# same tuple makes the shape check visible when run as a plain script too.
print((output.shape, len(state), state[0].shape, len(state[1]),
       state[1][0].shape))
|