diff --git a/codes/models/audio/tts/transformer_builders.py b/codes/models/audio/tts/transformer_builders.py index 8ce96f38..06f4a44d 100644 --- a/codes/models/audio/tts/transformer_builders.py +++ b/codes/models/audio/tts/transformer_builders.py @@ -59,8 +59,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text """ from transformers import GPT2Config, GPT2Model gpt_config = GPT2Config(vocab_size=256, # Unused. - n_positions=1, - n_ctx=1, + n_positions=max_mel_seq_len+max_text_seq_len, + n_ctx=max_mel_seq_len+max_text_seq_len, n_embd=model_dim, n_layer=layers, n_head=heads,