gpt_tts_hf inference fixes

This commit is contained in:
James Betker 2021-12-22 13:22:15 -07:00
parent 48e3ee9a5b
commit a42b94ab72
3 changed files with 58 additions and 30 deletions

View File

@ -125,7 +125,7 @@ class GptTtsHf(nn.Module):
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1):
def inference(self, text_inputs, cond_input, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
@ -133,28 +133,21 @@ class GptTtsHf(nn.Module):
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_inputs)
# Format conditioning inputs properly.
if len(cond_inputs.shape) == 3:
cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1}
if cond_inputs.shape[-1] > self.max_conditioning_length:
cond_inputs = cond_inputs[:,:,:,:self.max_conditioning_length]
# Randomly permute the conditioning spectrogram, to destroy any structure present.
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
conds = []
for k in range(cond_inputs.shape[1]):
conds.append(self.conditioning_encoder(cond_inputs[:, k]))
while len(conds) < self.max_conditioning_inputs:
conds.append(conds[-1])
conds = torch.stack(conds, dim=1)
emb = torch.cat([text_emb, conds], dim=1)
emb = torch.cat([text_emb, cond], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.START_MEL_TOKEN
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True, repetition_penalty=repetition_penalty)
return gen[:, fake_inputs.shape[1]:-1]
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]
@register_model

View File

@ -43,8 +43,7 @@ if __name__ == '__main__':
cond = inp if args.cond is None else load_audio(args.cond, 22050)
if cond.shape[-1] > 44100+10000:
cond = cond[:,10000:54100]
cond = torchaudio.transforms.Resample(22050, 10025)(cond.cpu()).cuda()
print("Performing inference..")
roundtripped = roundtrip_vocoding(dvae, diffusion, diffuser, inp, cond).cpu()
torchaudio.save('roundtrip_vocoded_output.wav', roundtripped.squeeze(0), 10025)
torchaudio.save('roundtrip_vocoded_output.wav', roundtripped.squeeze(0), 11025)

View File

@ -16,10 +16,6 @@ from utils.options import Loader
from utils.util import load_model_from_config
def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
return
# Loads multiple conditioning files at random from a folder.
def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100):
candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0]
@ -50,7 +46,38 @@ def load_conditioning(path, sample_rate=22050, cond_length=44100):
return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
def fix_autoregressive_output(codes, stop_token):
    """
    Truncate autoregressively generated codes at the first stop token and append fixed end padding.

    This padding fixes a mismatch between what the diffusion model was trained on and what the
    autoregressive code generator creates (which has no padding or end).

    This is highly specific to the DVAE being used, so this particular coding will not necessarily work
    if used with a different DVAE. The padding values can be inferred by feeding an audio clip padded
    with lots of zeros on the end through the DVAE and copying out the last few codes.

    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.

    Args:
        codes: integer tensor of generated codes. NOTE(review): expected to be 1-D (a single sequence);
            the concatenation with the 1-D padding below only works for that shape - confirm at callers.
        stop_token: value marking the end of generation inside `codes`.

    Returns:
        A tensor holding the codes that precede the first stop token (or all of `codes` when no stop
        token is present) followed by the 12 padding codes.
    """
    # Every position at which the stop token occurs; one row per match.
    stop_positions = (codes == stop_token).nonzero()
    if len(stop_positions) == 0:
        # Nothing to strip: the sequence is kept whole and only padded.
        print("No stop tokens found, enjoy that output of yours!")
    else:
        # Keep only what was generated before the first stop token. The explicit int() makes the
        # slice bound a plain Python integer rather than a one-element tensor.
        first_stop = int(stop_positions[0])
        codes = codes[:first_stop]

    # End-of-clip codes copied from the DVAE (see docstring); created on the same device as `codes`.
    padding = torch.tensor([83, 83, 83, 83, 83, 83, 83, 83, 83, 45, 45, 248],
                           dtype=torch.long, device=codes.device)
    return torch.cat([codes, padding])
if __name__ == '__main__':
# Preset conditioning voices: maps a preset name to the path of an audio clip.
# These are Windows paths written as ordinary string literals, so every backslash must be doubled;
# a single backslash before a digit (e.g. "\5") would be read as an octal escape, not a separator.
preselected_cond_voices = {
    'trump': 'D:\\data\\audio\\sample_voices\\trump.wav',
    'ryan_reynolds': 'D:\\data\\audio\\sample_voices\\ryan_reynolds.wav',
    'ed_sheeran': 'D:\\data\\audio\\sample_voices\\ed_sheeran.wav',
    'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
    'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav',
    'dan_carlin': 'Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav',
}
parser = argparse.ArgumentParser()
parser.add_argument('-opt_diffuse', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae.yml')
parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
@ -58,9 +85,11 @@ if __name__ == '__main__':
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\28500_gpt_ema.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="Please set this in the courier drone when we dock.")
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\50000_gpt.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='simmons')
parser.add_argument('-num_samples', type=int, help='How many outputs to produce.', default=1)
args = parser.parse_args()
print("Loading GPT TTS..")
@ -71,12 +100,15 @@ if __name__ == '__main__':
print("Loading data..")
text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
conds, cond_wav = load_conditioning(args.cond_path)
cond_path = args.cond_path if args.cond_preset is None else preselected_cond_voices[args.cond_preset]
conds, cond_wav = load_conditioning(cond_path)
print("Performing GPT inference..")
codes = gpt.inference(text, conds, num_beams=32, repetition_penalty=10.0)
codes = gpt.inference(text, conds, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=20, top_p=.95,
num_return_sequences=args.num_samples, length_penalty=.1, early_stopping=True)
# Delete the GPT TTS model to free up GPU memory
stop_token = gpt.STOP_MEL_TOKEN
del gpt
print("Loading DVAE..")
@ -86,5 +118,9 @@ if __name__ == '__main__':
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=50)
print("Performing vocoding..")
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, cond_wav, spectrogram_compression_factor=128, plt_spec=False)
torchaudio.save('gpt_tts_output.wav', wav.squeeze(0).cpu(), 10025)
# Perform vocoding on each batch element separately: Vocoding is very memory intensive.
for b in range(codes.shape[0]):
code = fix_autoregressive_output(codes[b], stop_token).unsqueeze(0)
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav,
spectrogram_compression_factor=128, plt_spec=False)
torchaudio.save(f'gpt_tts_output_{b}.wav', wav.squeeze(0).cpu(), 11025)