Merge pull request #1 from buaahsh/main

Fix decoder_embed_dim in Fairseq example
This commit is contained in:
Li Dong 2022-11-24 15:54:50 +08:00 committed by GitHub
commit afd9094fb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -271,10 +271,10 @@ class TranslationModel(FairseqEncoderDecoderModel):
output_projection.weight = decoder_embed_tokens.weight
else:
output_projection = torch.nn.Linear(
decoder_embed_dim, len(tgt_dict), bias=False
args.decoder_embed_dim, len(tgt_dict), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
)
encoder = cls.build_encoder(