fix fairseq example

main
sunyt32 2023-09-29 03:50:24 +07:00
parent 05a9628309
commit 50174a3078
1 changed files with 5 additions and 5 deletions

@ -345,7 +345,7 @@ def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args)
@ -354,7 +354,7 @@ def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@ -363,7 +363,7 @@ def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@ -372,7 +372,7 @@ def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args)
@ -381,7 +381,7 @@ def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args)