Fix gpt_tts_hf inference
This commit is contained in:
parent 712d746e9b
commit 53858b2055
@@ -145,15 +145,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         if input_ids.shape[1] != 1:
             text_inputs = input_ids[:, mel_len:]
             text_emb = self.transformer.get_input_embeddings()(text_inputs)
-            text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
+            if self.text_pos_embedding is not None:
+                text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
                 mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
             else:
                 mel_emb = self.cached_mel_emb
             emb = torch.cat([mel_emb, text_emb], dim=1)
         else:
-            emb = self.transformer.get_input_embeddings()(input_ids) + \
-                  self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0)
+            emb = self.transformer.get_input_embeddings()(input_ids)
+            if self.text_pos_embedding is not None:
+                emb = emb + self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0)
 
         transformer_outputs = self.transformer(
             inputs_embeds=emb,
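Both branches now guard the external text position embedding, so GPT2InferenceModel can run with text_pos_embedding=None (the wrapped GPT-2 then presumably falls back on its own learned position embeddings; the train_gpt_tts_no_pos experiment name further down points the same way). A minimal sketch of the guarded pattern, with hypothetical shapes:

import torch
import torch.nn as nn

# Sketch of the optional positional-embedding guard above; shapes are hypothetical.
text_pos_embedding = None  # or e.g. nn.Embedding(200, 512) when an external embedding is used

text_emb = torch.randn(2, 120, 512)  # (batch, text_tokens, channels)
if text_pos_embedding is not None:
    positions = torch.arange(text_emb.shape[1], device=text_emb.device)
    text_emb = text_emb + text_pos_embedding(positions)  # (120, 512) broadcasts over the batch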
@@ -115,12 +115,18 @@ class GptTtsHf(nn.Module):
 
     def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1):
         if not hasattr(self, 'inference_model'):
-            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
+            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
 
         text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
         text_emb = self.text_embedding(text_inputs)
 
+        # Format conditioning inputs properly.
+        if len(cond_inputs.shape) == 3:
+            cond_inputs = cond_inputs.unsqueeze(1)  # Format a single conditioning input as a set of {1}
+        if cond_inputs.shape[-1] > self.max_conditioning_length:
+            cond_inputs = cond_inputs[:,:,:,:self.max_conditioning_length]
+
         conds = []
         for k in range(cond_inputs.shape[1]):
             conds.append(self.conditioning_encoder(cond_inputs[:, k]))
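The new formatting block lets inference() accept either a single conditioning spectrogram or a set of them, and clamps each clip to max_conditioning_length. A standalone sketch of the shape handling (the mel bin count and the length cap are assumed values, not taken from this diff):

import torch

max_conditioning_length = 80  # assumed cap; the real value comes from the model config

cond_inputs = torch.randn(1, 80, 200)       # single clip: (batch, mel_bins, frames)
if len(cond_inputs.shape) == 3:
    cond_inputs = cond_inputs.unsqueeze(1)  # lift to a set of one: (1, 1, 80, 200)
if cond_inputs.shape[-1] > max_conditioning_length:
    cond_inputs = cond_inputs[:, :, :, :max_conditioning_length]
print(cond_inputs.shape)                    # torch.Size([1, 1, 80, 80])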
@@ -20,6 +20,7 @@ def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
         return
 
 
+# Loads multiple conditioning files at random from a folder.
 def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100):
     candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0]
     # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
@@ -37,6 +38,17 @@ def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length
     return torch.stack(related_mels, dim=0).unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
 
 
+def load_conditioning(path, sample_rate=22050, cond_length=44100):
+    rel_clip = load_audio(path, sample_rate)
+    gap = rel_clip.shape[-1] - cond_length
+    if gap < 0:
+        rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
+    elif gap > 0:
+        rand_start = random.randint(0, gap)
+        rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
+    mel_clip = wav_to_mel(rel_clip.unsqueeze(0)).squeeze(0)
+    return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
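The new load_conditioning() reduces a single clip to exactly cond_length samples: shorter clips are zero-padded on the right, longer ones cropped at a random offset. A self-contained illustration of that pad-or-crop logic:

import random
import torch
import torch.nn.functional as F

cond_length = 44100  # target length in samples, matching the function default

short = torch.randn(1, 30000)         # (channels, samples), shorter than target
gap = short.shape[-1] - cond_length   # negative -> zero-pad on the right
short = F.pad(short, pad=(0, abs(gap)))
assert short.shape[-1] == cond_length

longer = torch.randn(1, 90000)        # longer than target
gap = longer.shape[-1] - cond_length  # positive -> crop at a random offset
start = random.randint(0, gap)
longer = longer[:, start:start + cond_length]
assert longer.shape[-1] == cond_length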
@@ -46,10 +58,9 @@ if __name__ == '__main__':
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
     parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
     parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
-    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\48000_gpt.pth')
-    parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
-    parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Y:\\clips\\podcasts-0\\8816_20210511-Pay Taxes Less Frequently_ We\'re Interested')
-    parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3)
+    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\28500_gpt_ema.pth')
+    parser.add_argument('-text', type=str, help='Text to speak.', default="Please set this in the courier drone when we dock.")
+    parser.add_argument('-cond_path', type=str, help='Path to conditioning sample.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav')
     args = parser.parse_args()
 
     print("Loading GPT TTS..")
@@ -60,7 +71,7 @@ if __name__ == '__main__':
 
     print("Loading data..")
     text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
-    conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond)
+    conds, cond_wav = load_conditioning(args.cond_path)
 
     print("Performing GPT inference..")
     codes = gpt.inference(text, conds, num_beams=32, repetition_penalty=10.0)
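The repeat_interleave branch in the first hunk exists because beam search (num_beams=32 here) multiplies the batch dimension of the token stream while cached_mel_emb keeps its original batch size. A runnable demonstration of that expansion, with hypothetical shapes:

import torch

num_beams = 32
cached_mel_emb = torch.randn(1, 80, 512)         # (batch, cond_tokens, channels), hypothetical shape
text_emb = torch.randn(1 * num_beams, 120, 512)  # batch axis expanded by beam search

mel_emb = cached_mel_emb.repeat_interleave(text_emb.shape[0] // cached_mel_emb.shape[0], 0)
print(mel_emb.shape)  # torch.Size([32, 80, 512]) -- now concatenable with text_emb along dim=1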