Update MT config

This commit is contained in:
Shuming Ma 2023-07-31 09:17:03 -07:00
parent ea07735c7b
commit 5356b252c4

View File

@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding from fairseq.modules import PositionalEmbedding
from torch import Tensor from torch import Tensor
from torchscale.architecture.config import DecoderConfig, EncoderConfig from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
from torchscale.architecture.encoder import Encoder from torchscale.architecture.encoder import Encoder
from .language_modeling import LMDecoder as MTDecoder from .language_modeling import LMDecoder as MTDecoder
@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
@classmethod @classmethod
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary): def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
config = EncoderConfig() config = EncoderDecoderConfig()
config.override(args) config.override(args)
return MTEncoder( return MTEncoder(
@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
def build_decoder( def build_decoder(
cls, args, embed_tokens, embed_positions, output_projection, dictionary cls, args, embed_tokens, embed_positions, output_projection, dictionary
): ):
config = DecoderConfig() config = EncoderDecoderConfig()
config.override(args) config.override(args)
return MTDecoder( return MTDecoder(