From 8917c02a4da3feb77e6bc2292a630fe39c4ad77e Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 12 Dec 2021 19:51:44 -0700 Subject: [PATCH] gpt_tts_hf inference first pass --- codes/models/gpt_voice/gpt_tts_hf.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index 315f1e27..c57d492e 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -7,6 +7,7 @@ from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedM from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.utils.model_parallel_utils import get_device_map, assert_device_map +from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel from models.gpt_voice.mini_encoder import AudioMiniEncoder from models.tacotron2.text import symbols from trainer.networks import register_model @@ -103,6 +104,33 @@ class GptTtsHf(nn.Module): loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits + def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8): + if not hasattr(self, 'inference_model'): + self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head) + + text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN) + text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN) + text_emb = self.gpt.get_input_embeddings()(text_targets) + text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device)) + + conds = [] + for k in range(cond_inputs.shape[1]): + conds.append(self.conditioning_encoder(cond_inputs[:, k])) + while len(conds) < self.max_conditioning_inputs: + conds.append(conds[-1]) + conds = torch.stack(conds, dim=1) + conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device)) + + emb = torch.cat([text_emb, conds], dim=1) + self.inference_model.store_mel_emb(emb) + + fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) + fake_inputs[:,-1] = self.START_MEL_TOKEN + + gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0, + max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True) + return gen[:, self.max_mel_frames:] + @register_model def register_gpt_tts_hf(opt_net, opt):