forked from mrq/DL-Art-School
gpt_tts_hf inference first pass
This commit is contained in:
parent
63bf135b93
commit
8917c02a4d
|
@ -7,6 +7,7 @@ from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedM
|
|||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
|
||||
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
||||
from models.gpt_voice.mini_encoder import AudioMiniEncoder
|
||||
from models.tacotron2.text import symbols
|
||||
from trainer.networks import register_model
|
||||
|
@ -103,6 +104,33 @@ class GptTtsHf(nn.Module):
|
|||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
|
||||
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8):
|
||||
if not hasattr(self, 'inference_model'):
|
||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
||||
|
||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
||||
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.gpt.get_input_embeddings()(text_targets)
|
||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
||||
|
||||
conds = []
|
||||
for k in range(cond_inputs.shape[1]):
|
||||
conds.append(self.conditioning_encoder(cond_inputs[:, k]))
|
||||
while len(conds) < self.max_conditioning_inputs:
|
||||
conds.append(conds[-1])
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
|
||||
|
||||
emb = torch.cat([text_emb, conds], dim=1)
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
|
||||
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||
return gen[:, self.max_mel_frames:]
|
||||
|
||||
|
||||
@register_model
|
||||
def register_gpt_tts_hf(opt_net, opt):
|
||||
|
|
Loading…
Reference in New Issue
Block a user