diff --git a/examples/fairseq/models/machine_translation.py b/examples/fairseq/models/machine_translation.py index 05e3633..f2e1b1a 100644 --- a/examples/fairseq/models/machine_translation.py +++ b/examples/fairseq/models/machine_translation.py @@ -271,10 +271,10 @@ class TranslationModel(FairseqEncoderDecoderModel): output_projection.weight = decoder_embed_tokens.weight else: output_projection = torch.nn.Linear( - decoder_embed_dim, len(tgt_dict), bias=False + args.decoder_embed_dim, len(tgt_dict), bias=False ) torch.nn.init.normal_( - output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5 + output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5 ) encoder = cls.build_encoder(