diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py index 7831f1de..6ed288b0 100644 --- a/codes/models/audio/tts/autoregressive_codegen.py +++ b/codes/models/audio/tts/autoregressive_codegen.py @@ -201,6 +201,8 @@ class AutoregressiveCodegen(nn.Module): self.START_TOKEN=8192 self.STOP_TOKEN=8193 + self.max_text_token_id = num_text_tokens + self.max_mel_token_id = num_mel_tokens self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False) self.encoder = CheckpointedXTransformerWrapper( num_tokens=num_text_tokens, @@ -243,6 +245,9 @@ class AutoregressiveCodegen(nn.Module): } def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True): + assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}' + assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}' + # Format mel_codes with a stop token on the end. mel_lengths = wav_lengths // 1024 + 1 for b in range(mel_codes.shape[0]):