fix fairseq example

This commit is contained in:
sunyt32 2023-09-29 03:50:24 +00:00
parent 05a9628309
commit 50174a3078

View File

@@ -345,7 +345,7 @@ def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456) args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24) args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args) retnet_base_architecture(args)
@@ -354,7 +354,7 @@ def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280) args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32) args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args) retnet_base_architecture(args)
@@ -363,7 +363,7 @@ def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912) args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32) args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args) retnet_base_architecture(args)
@@ -372,7 +372,7 @@ def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560) args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40) args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args) retnet_base_architecture(args)
@@ -381,7 +381,7 @@ def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824) args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64) args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args) retnet_base_architecture(args)