diff --git a/torchscale/architecture/encoder_decoder.py b/torchscale/architecture/encoder_decoder.py index d91313f..ed64641 100644 --- a/torchscale/architecture/encoder_decoder.py +++ b/torchscale/architecture/encoder_decoder.py @@ -11,7 +11,7 @@ from torchscale.architecture.encoder import Encoder class EncoderDecoder(nn.Module): def __init__( self, - args, + args: EncoderDecoderConfig, encoder_embed_tokens=None, encoder_embed_positions=None, decoder_embed_tokens=None,