Update MT config
This commit is contained in:
parent
ea07735c7b
commit
5356b252c4
|
@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
|
|||
from fairseq.modules import PositionalEmbedding
|
||||
from torch import Tensor
|
||||
|
||||
from torchscale.architecture.config import DecoderConfig, EncoderConfig
|
||||
from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
|
||||
from torchscale.architecture.encoder import Encoder
|
||||
|
||||
from .language_modeling import LMDecoder as MTDecoder
|
||||
|
@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
|||
|
||||
@classmethod
|
||||
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
|
||||
config = EncoderConfig()
|
||||
config = EncoderDecoderConfig()
|
||||
config.override(args)
|
||||
|
||||
return MTEncoder(
|
||||
|
@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
|||
def build_decoder(
|
||||
cls, args, embed_tokens, embed_positions, output_projection, dictionary
|
||||
):
|
||||
config = DecoderConfig()
|
||||
config = EncoderDecoderConfig()
|
||||
config.override(args)
|
||||
|
||||
return MTDecoder(
|
||||
|
|
Loading…
Reference in New Issue
Block a user