From 7c578eb59bcadba861041cedcaf8c0caea52f722 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 7 Apr 2022 21:22:46 -0600
Subject: [PATCH] Fix inference in new autoregressive_codegen

---
 .../audio/tts/autoregressive_codegen.py    | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py
index 22403233..acedc4a4 100644
--- a/codes/models/audio/tts/autoregressive_codegen.py
+++ b/codes/models/audio/tts/autoregressive_codegen.py
@@ -5,7 +5,7 @@ from transformers import GPT2PreTrainedModel, GPT2Config
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from models.arch_util import AttentionBlock
-from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
+from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder
 from trainer.networks import register_model
 
 
@@ -86,7 +86,7 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
+        hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
         logits = self.transformer.decoder.to_logits(hidden_states)
 
         if not return_dict:
@@ -241,21 +241,24 @@ class AutoregressiveCodegen(nn.Module):
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
         return loss_mel
 
-    def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
-        if not hasattr(self, 'inference_model'):
-            self.inference_model = InferenceModel(self)
-
+    def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
+        inference_model = InferenceModel(self)
+        # Build the context
         if len(conditioning_signal.shape) != 4:
             conditioning_signal = conditioning_signal.unsqueeze(1)
         cond_embs = []
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
-        enc_text = self.encoder(text_codes, return_embeddings=True)
-        context = torch.cat([cond_emb, enc_text], dim=1)
-        self.inference_model.store_context(context)
+        _, enc_text = self.encoder(text_codes, return_hiddens=True)
+        # Interleave cond_emb into the first few contexts.
+        full_context = enc_text
+        full_context[1] = cond_emb
+        full_context[3] = cond_emb
+        full_context[6] = cond_emb
+        inference_model.store_context(full_context)
 
-        gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
+        gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
                                             max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                             **hf_generate_kwargs)
         return gen.sequences
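
Usage note (not part of the patch): a minimal sketch of how the patched generate() entry point might be invoked after this change, assuming AutoregressiveCodegen can be constructed with usable default arguments and that conditioning_signal is a batched MEL clip shaped (batch, n_mels, frames). The constructor call, tensor shapes, and token ranges below are illustrative assumptions, not taken from this diff.

    import torch
    from models.audio.tts.autoregressive_codegen import AutoregressiveCodegen

    # All shapes and constructor arguments here are assumptions for illustration only.
    model = AutoregressiveCodegen().eval()           # assumes the constructor has usable defaults
    conditioning = torch.randn(1, 80, 400)           # assumed (batch, n_mels, frames) conditioning clip
    text_codes = torch.randint(0, 255, (1, 60))      # assumed tokenized text ids
    with torch.no_grad():
        # After this patch, generate() builds a fresh InferenceModel per call,
        # feeds the encoder hidden states (with cond_emb interleaved) as the decoder
        # context, and defaults to max_tokens=256.
        mel_codes = model.generate(conditioning, text_codes, max_tokens=256)
    print(mel_codes.shape)  # (batch, generated sequence length)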