diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py
index b3c32ee4..b4003aae 100644
--- a/codes/models/audio/tts/autoregressive_codegen.py
+++ b/codes/models/audio/tts/autoregressive_codegen.py
@@ -196,17 +196,15 @@ class CheckpointedXTransformerWrapper(nn.Module):
 
 
 class AutoregressiveCodegen(nn.Module):
-    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
-                 max_mel_tokens=4000, dropout=.1):
+    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
         super().__init__()
 
         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
-        self.max_mel_tokens = max_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
         self.encoder = CheckpointedXTransformerWrapper(
             num_tokens=num_text_tokens,
-            max_seq_len=max_text_tokens,
+            use_pos_emb=False,
             attn_layers = Encoder(
                 depth=depth//2,
                 heads=model_dim//64,
@@ -221,7 +219,7 @@ class AutoregressiveCodegen(nn.Module):
             ))
         self.decoder = CheckpointedXTransformerWrapper(
             num_tokens=num_mel_tokens,
-            max_seq_len=max_mel_tokens,
+            use_pos_emb=False,
             attn_layers=Decoder(
                 depth=depth,
                 heads=model_dim//64,
@@ -268,7 +266,7 @@ class AutoregressiveCodegen(nn.Module):
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
         return loss_mel
 
-    def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
+    def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
             self.inference_model = InferenceModel(self)
 
@@ -283,7 +281,7 @@ class AutoregressiveCodegen(nn.Module):
 
         self.inference_model.store_context(context)
         gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
-                                            max_length=self.max_mel_tokens, output_attentions=False, return_dict_in_generate=True,
+                                            max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                             **hf_generate_kwargs)
         return gen
 
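
For reviewers, a minimal usage sketch of the new signatures. The hyperparameter values, tensor shapes, and the num_beams kwarg below are illustrative assumptions, not taken from this repo's configs:

    import torch
    from codes.models.audio.tts.autoregressive_codegen import AutoregressiveCodegen

    # The constructor no longer takes max_text_tokens/max_mel_tokens; with
    # use_pos_emb=False there is no fixed positional-embedding table to size.
    # model_dim=512 and depth=12 are hypothetical values (model_dim must be a
    # multiple of 64, since heads = model_dim // 64).
    model = AutoregressiveCodegen(model_dim=512, depth=12).eval()

    conditioning = torch.randn(1, 80, 388)        # (batch, mel_bins, frames); shape is illustrative
    text_codes = torch.randint(0, 256, (1, 120))  # token ids in [0, num_text_tokens)

    with torch.no_grad():
        # max_tokens replaces the old self.max_mel_tokens cap at inference time;
        # any extra kwargs (e.g. num_beams) are forwarded to HF generate().
        gen = model.generate(conditioning, text_codes, max_tokens=512, num_beams=1)

The net effect: sequence-length limits move from construction time (where they sized positional embeddings) to a per-call generation cap, so the same checkpoint can be run with different output-length budgets.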