diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py index 9f2198af..07429a5e 100644 --- a/codes/models/audio/tts/autoregressive_codegen.py +++ b/codes/models/audio/tts/autoregressive_codegen.py @@ -217,6 +217,7 @@ class AutoregressiveCodegen(nn.Module): ff_mult=1, rotary_pos_emb=True, )) + self.encoder.to_logits = nn.Identity() # This is unused. self.decoder = CheckpointedXTransformerWrapper( num_tokens=num_mel_tokens, use_pos_emb=False,