Fix gpt_tts_hf inference

This commit is contained in:
James Betker 2021-12-20 17:45:26 -07:00
parent 712d746e9b
commit 53858b2055
3 changed files with 28 additions and 9 deletions

View File

@ -145,15 +145,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
if input_ids.shape[1] != 1:
text_inputs = input_ids[:, mel_len:]
text_emb = self.transformer.get_input_embeddings()(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
if self.text_pos_embedding is not None:
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.transformer.get_input_embeddings()(input_ids) + \
self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0)
emb = self.transformer.get_input_embeddings()(input_ids)
if self.text_pos_embedding is not None:
emb = emb + self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0)
transformer_outputs = self.transformer(
inputs_embeds=emb,

View File

@ -115,12 +115,18 @@ class GptTtsHf(nn.Module):
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_inputs)
# Format conditioning inputs properly.
if len(cond_inputs.shape) == 3:
cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1}
if cond_inputs.shape[-1] > self.max_conditioning_length:
cond_inputs = cond_inputs[:,:,:,:self.max_conditioning_length]
conds = []
for k in range(cond_inputs.shape[1]):
conds.append(self.conditioning_encoder(cond_inputs[:, k]))

View File

@ -20,6 +20,7 @@ def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
return
# Loads multiple conditioning files at random from a folder.
def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100):
candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0]
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
@ -37,6 +38,17 @@ def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length
return torch.stack(related_mels, dim=0).unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
def load_conditioning(path, sample_rate=22050, cond_length=44100):
    """Load one conditioning clip and fit it to exactly ``cond_length`` samples.

    The audio is read with ``load_audio(path, sample_rate)``. If its last
    dimension is shorter than ``cond_length`` it is right-padded with
    ``F.pad``; if it is longer, a window of ``cond_length`` samples is cut
    out at a random offset. A clip that already has the right length is
    used unchanged.

    Args:
        path: Location of the audio file, passed straight to ``load_audio``.
        sample_rate: Sample rate forwarded to ``load_audio``.
        cond_length: Desired clip length, measured along the last dimension.

    Returns:
        A ``(mel, clip)`` tuple. Both tensors get a new leading dimension
        and are moved to the CUDA device.
    """
    clip = load_audio(path, sample_rate)
    excess = clip.shape[-1] - cond_length

    if excess > 0:
        # Too long: keep a randomly positioned window. randint is inclusive
        # on both ends, so the window can end exactly at the clip's end.
        # NOTE(review): the slice indexes dim 1, so this presumably expects
        # a 2-D (channels, samples) tensor from load_audio -- confirm.
        offset = random.randint(0, excess)
        clip = clip[:, offset:offset + cond_length]
    elif excess < 0:
        # Too short: extend the end of the last dimension by the shortfall.
        # NOTE(review): assumes F is torch.nn.functional, whose default pad
        # mode fills with zeros -- confirm.
        clip = F.pad(clip, pad=(0, -excess))

    mel = wav_to_mel(clip.unsqueeze(0)).squeeze(0)
    return mel.unsqueeze(0).cuda(), clip.unsqueeze(0).cuda()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@ -46,10 +58,9 @@ if __name__ == '__main__':
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\48000_gpt.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Y:\\clips\\podcasts-0\\8816_20210511-Pay Taxes Less Frequently_ We\'re Interested')
parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3)
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\28500_gpt_ema.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="Please set this in the courier drone when we dock.")
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav')
args = parser.parse_args()
print("Loading GPT TTS..")
@ -60,7 +71,7 @@ if __name__ == '__main__':
print("Loading data..")
text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond)
conds, cond_wav = load_conditioning(args.cond_path)
print("Performing GPT inference..")
codes = gpt.inference(text, conds, num_beams=32, repetition_penalty=10.0)