diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py index 14e7bfc..74fa3f9 100644 --- a/torchscale/architecture/config.py +++ b/torchscale/architecture/config.py @@ -77,7 +77,6 @@ class EncoderConfig(Config): self.deepnorm = False - class DecoderConfig(Config): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -100,8 +99,8 @@ class DecoderConfig(Config): self.deepnorm = False - class EncoderDecoderConfig(EncoderConfig, DecoderConfig): def __init__(self, **kwargs): super().__init__(**kwargs) self.share_all_embeddings = kwargs.pop("share_all_embeddings", False) +