From 50174a3078b47c04f95a8a7ca64ff5397142ded7 Mon Sep 17 00:00:00 2001
From: sunyt32
Date: Fri, 29 Sep 2023 03:50:24 +0000
Subject: [PATCH] fix fairseq example

---
 examples/fairseq/models/retnet.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/fairseq/models/retnet.py b/examples/fairseq/models/retnet.py
index 1a5b329..b6364c4 100644
--- a/examples/fairseq/models/retnet.py
+++ b/examples/fairseq/models/retnet.py
@@ -345,7 +345,7 @@ def retnet_xl(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
     args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
     args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
-    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
     args.decoder_layers = getattr(args, "decoder_layers", 24)
     retnet_base_architecture(args)
 
@@ -354,7 +354,7 @@ def retnet_3b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
     args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
     args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
-    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
+    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
     args.decoder_layers = getattr(args, "decoder_layers", 32)
     retnet_base_architecture(args)
 
@@ -363,7 +363,7 @@ def retnet_7b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
     args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
     args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
-    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
     args.decoder_layers = getattr(args, "decoder_layers", 32)
     retnet_base_architecture(args)
 
@@ -372,7 +372,7 @@ def retnet_13b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
     args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
     args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
-    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
+    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
     args.decoder_layers = getattr(args, "decoder_layers", 40)
     retnet_base_architecture(args)
 
@@ -381,7 +381,7 @@ def retnet_65b(args):
     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
     args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
     args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
-    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
+    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
     args.decoder_layers = getattr(args, "decoder_layers", 64)
     retnet_base_architecture(args)