Update MT config
This commit is contained in:
parent
ea07735c7b
commit
5356b252c4
|
@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
|
||||||
from fairseq.modules import PositionalEmbedding
|
from fairseq.modules import PositionalEmbedding
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from torchscale.architecture.config import DecoderConfig, EncoderConfig
|
from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
|
||||||
from torchscale.architecture.encoder import Encoder
|
from torchscale.architecture.encoder import Encoder
|
||||||
|
|
||||||
from .language_modeling import LMDecoder as MTDecoder
|
from .language_modeling import LMDecoder as MTDecoder
|
||||||
|
@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
|
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
|
||||||
config = EncoderConfig()
|
config = EncoderDecoderConfig()
|
||||||
config.override(args)
|
config.override(args)
|
||||||
|
|
||||||
return MTEncoder(
|
return MTEncoder(
|
||||||
|
@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
def build_decoder(
|
def build_decoder(
|
||||||
cls, args, embed_tokens, embed_positions, output_projection, dictionary
|
cls, args, embed_tokens, embed_positions, output_projection, dictionary
|
||||||
):
|
):
|
||||||
config = DecoderConfig()
|
config = EncoderDecoderConfig()
|
||||||
config.override(args)
|
config.override(args)
|
||||||
|
|
||||||
return MTDecoder(
|
return MTDecoder(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user