decoder_embed_dim -> args.decoder_embed_dim

This commit is contained in:
Shaohan Huang 2022-11-24 14:30:39 +08:00
parent 51abba7c8b
commit bdf759f116

View File

@ -271,10 +271,10 @@ class TranslationModel(FairseqEncoderDecoderModel):
output_projection.weight = decoder_embed_tokens.weight
else:
output_projection = torch.nn.Linear(
decoder_embed_dim, len(tgt_dict), bias=False
args.decoder_embed_dim, len(tgt_dict), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
)
encoder = cls.build_encoder(