gpt_tts_hf inference first pass

This commit is contained in:
James Betker 2021-12-12 19:51:44 -07:00
parent 63bf135b93
commit 8917c02a4d

View File

@ -7,6 +7,7 @@ from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedM
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.tacotron2.text import symbols
from trainer.networks import register_model
@ -103,6 +104,33 @@ class GptTtsHf(nn.Module):
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8):
    """Generate a token sequence for a batch of text plus conditioning inputs.

    The text and conditioning embeddings are computed here and handed to the
    wrapped inference model through ``store_mel_emb``. ``generate`` is then
    driven with a prompt of placeholder ids standing in for that embedded
    prefix, followed by a single ``START_MEL_TOKEN``.

    Args:
        text_inputs: Integer token ids, indexed here as (batch, sequence).
        cond_inputs: Conditioning inputs; dim 1 enumerates the individual
            conditioning items, each encoded by ``self.conditioning_encoder``.
        do_sample: Forwarded unchanged to ``generate``.
        temperature: Forwarded unchanged to ``generate``.
        num_beams: Forwarded unchanged to ``generate``.

    Returns:
        The part of the ``generate`` output that follows the prompt.
    """
    # Built lazily on first use and cached on the module for later calls.
    # NOTE(review): the wrapper receives the *text* position embedding and
    # *text* head even though the prompt ends in START_MEL_TOKEN -- confirm
    # this is intended.
    if not hasattr(self, 'inference_model'):
        self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)

    # Prepend the start token, then pad the *start-prefixed* sequence up to
    # max_symbols_per_phrase. Padding the raw text_inputs here instead would
    # drop the start token and leave the prefix one position short of the
    # placeholder prompt built below.
    text_targets = F.pad(text_inputs, (1, 0), value=self.START_TEXT_TOKEN)
    text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
    text_emb = self.gpt.get_input_embeddings()(text_targets)
    text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))

    # Encode each conditioning item, then repeat the last encoding until
    # there are max_conditioning_inputs entries.
    conds = [self.conditioning_encoder(cond_inputs[:, k]) for k in range(cond_inputs.shape[1])]
    while len(conds) < self.max_conditioning_inputs:
        conds.append(conds[-1])
    conds = torch.stack(conds, dim=1)
    conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))

    emb = torch.cat([text_emb, conds], dim=1)
    self.inference_model.store_mel_emb(emb)

    # One placeholder id per embedded prefix position, plus a trailing
    # START_MEL_TOKEN from which generation continues.
    fake_inputs = torch.full((text_inputs.shape[0], self.max_symbols_per_phrase + self.max_conditioning_inputs + 1,),
                             fill_value=1, dtype=torch.long, device=text_inputs.device)
    fake_inputs[:, -1] = self.START_MEL_TOKEN

    # NOTE(review): bos_token_id is NUMBER_SYMBOLS and pad/eos are 0 rather
    # than MEL start/stop tokens, and max_length does not account for the
    # conditioning slots or the start token in the prompt -- confirm.
    gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
                                        max_length=self.max_symbols_per_phrase + self.max_mel_frames, temperature=temperature,
                                        num_beams=num_beams, use_cache=True)

    # Strip the prompt by its actual length (assumes generate returns the
    # prompt followed by the newly generated ids); max_mel_frames is
    # unrelated to where the prompt ends.
    return gen[:, fake_inputs.shape[1]:]
@register_model
def register_gpt_tts_hf(opt_net, opt):