Python Class for RNN Encoder

moye Lv6

class RNNEncoder(FairseqEncoder):
    """Bidirectional multi-layer GRU encoder for a fairseq seq2seq model.

    Embeds the source tokens, applies dropout, runs a bidirectional ``nn.GRU``
    (time-major, ``batch_first=False``) and returns the per-timestep outputs,
    the final hidden state of every layer with both directions concatenated,
    and a padding mask over the source positions.
    """

    def __init__(self, args, dictionary, embed_tokens):
        """Build the encoder.

        Args:
            args: configuration namespace; this class reads
                ``encoder_embed_dim``, ``encoder_ffn_embed_dim``,
                ``encoder_layers`` and ``dropout`` from it.
            dictionary: source dictionary, forwarded to ``FairseqEncoder``;
                its ``pad()`` index is used to build the padding mask.
            embed_tokens: module mapping token ids to embedding vectors. Its
                output size must equal ``args.encoder_embed_dim`` because that
                is the GRU's input size.
        """
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens

        self.embed_dim = args.encoder_embed_dim
        # The GRU hidden size (per direction) is taken from encoder_ffn_embed_dim.
        self.hidden_dim = args.encoder_ffn_embed_dim
        self.num_layers = args.encoder_layers

        self.dropout_in_module = nn.Dropout(args.dropout)
        self.rnn = nn.GRU(
            self.embed_dim,
            self.hidden_dim,
            self.num_layers,
            dropout=args.dropout,
            batch_first=False,  # inputs/outputs are (seq_len, batch, features)
            bidirectional=True,
        )
        self.dropout_out_module = nn.Dropout(args.dropout)

        self.padding_idx = dictionary.pad()

    def combine_bidir(self, outs, bsz: int):
        """Concatenate the forward/backward final hidden states per layer.

        Args:
            outs: the GRU's final hidden state, shaped
                ``(num_layers * 2, bsz, hidden_dim)``, where the two directions
                of each layer are adjacent (PyTorch's ``h_n`` layout).
            bsz: batch size.

        Returns:
            Tensor of shape ``(num_layers, bsz, 2 * hidden_dim)``. For each
            layer, the forward state comes first, then the backward state.
        """
        # (layers, dirs, B, H) -> (layers, B, dirs, H) so the two directions
        # become adjacent in memory and can be flattened into the last axis.
        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(self.num_layers, bsz, -1)

    def forward(self, src_tokens, **unused):
        """Encode a batch of source sentences.

        Args:
            src_tokens: token ids of shape ``(batch, seq_len)``.
            **unused: any extra keyword arguments are accepted and ignored.

        Returns:
            A 3-tuple ``(outputs, final_hiddens, encoder_padding_mask)``:

            - outputs: ``(seq_len, batch, 2 * hidden_dim)`` GRU outputs after
              dropout.
            - final_hiddens: ``(num_layers, batch, 2 * hidden_dim)`` final
              hidden state of each layer, both directions concatenated.
            - encoder_padding_mask: ``(seq_len, batch)``, ``True`` where the
              source token equals the padding index.

        NOTE(review): padding positions are fed through the GRU as ordinary
        tokens (no ``pack_padded_sequence``), so ``final_hiddens`` for padded
        sentences also reflects pad steps. Consumers must rely on the mask.
        """
        # Unpacking also asserts src_tokens is 2-D.
        bsz, _ = src_tokens.size()

        x = self.embed_tokens(src_tokens)
        x = self.dropout_in_module(x)

        # B x T x C -> T x B x C, because the GRU is time-major.
        x = x.transpose(0, 1)

        # Explicit zero initial state: (num_layers * num_directions, B, H).
        h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)
        x, final_hiddens = self.rnn(x, h0)
        outputs = self.dropout_out_module(x)

        # (num_layers * 2, B, H) -> (num_layers, B, 2 * H)
        final_hiddens = self.combine_bidir(final_hiddens, bsz)

        # (B, T) -> (T, B) to match the time-major outputs.
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
        return (
            outputs,  # seq_len x batch x 2*hidden
            final_hiddens,  # num_layers x batch x 2*hidden
            encoder_padding_mask,  # seq_len x batch
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension of the encoder output.

        fairseq's beam search calls this when it reorders or expands
        hypotheses.

        Args:
            encoder_out: the tuple returned by :meth:`forward`. All three
                tensors carry the batch on dimension 1.
            new_order: index tensor selecting batch entries.

        Returns:
            A new 3-tuple with each tensor indexed along dim 1 by ``new_order``.
        """
        return (
            encoder_out[0].index_select(1, new_order),
            encoder_out[1].index_select(1, new_order),
            encoder_out[2].index_select(1, new_order),
        )

  • 标题: Python Class for RNN Encoder
  • 作者: moye
  • 创建于 : 2022-08-16 15:01:08
  • 更新于 : 2025-12-11 14:39:48
  • 链接: https://www.kanes.top/2022/08/16/Python%20Class%20for%20RNN%20Encoder/
  • 版权声明: 本文章采用 CC BY-NC-SA 4.0 进行许可。
评论
目录
Python Class for RNN Encoder