diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py
index f69f99d1..ce1113b7 100644
--- a/codes/models/audio/tts/autoregressive_codegen.py
+++ b/codes/models/audio/tts/autoregressive_codegen.py
@@ -86,7 +86,8 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, use_cache=use_cache)
+        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
+                                       use_cache=use_cache, expected_seq_len=150)
         if use_cache:
             hidden_states, present_key_values = out
         else:
@@ -168,6 +169,8 @@ class AutoregressiveCodegen(nn.Module):
 
         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
+        self.START_TEXT_TOKEN = 255
+        self.STOP_TEXT_TOKEN = 0
         self.max_text_token_id = num_text_tokens
         self.max_mel_token_id = num_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
@@ -231,6 +234,9 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        # Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
+        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
+        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
         _, enc_text = self.encoder(text_codes, return_hiddens=True)
         # Interleave cond_emb into the first few contexts.
         full_context = enc_text
@@ -255,6 +261,8 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
+        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
         _, enc_text = self.encoder(text_codes, return_hiddens=True)
         # Interleave cond_emb into the first few contexts.
         full_context = enc_text
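
A minimal standalone sketch of the text-token framing that the added F.pad calls perform, assuming text_codes is a (batch, seq_len) LongTensor of text token ids; the constants below simply mirror the START_TEXT_TOKEN / STOP_TEXT_TOKEN values introduced in the diff and the stand-in batch is hypothetical:

    import torch
    import torch.nn.functional as F

    START_TEXT_TOKEN = 255  # value added in __init__ above
    STOP_TEXT_TOKEN = 0     # value added in __init__ above

    text_codes = torch.randint(1, 255, (2, 10))  # stand-in batch of text token ids
    # Prepend a fixed start token and append a fixed stop token so that the
    # relative positional embeddings see permanent anchors at both ends of the text.
    text_codes = F.pad(text_codes, (1, 0), value=START_TEXT_TOKEN)
    text_codes = F.pad(text_codes, (0, 1), value=STOP_TEXT_TOKEN)
    print(text_codes.shape)  # torch.Size([2, 12])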