gpt_tts_hf inference fixes
This commit is contained in:
parent
48e3ee9a5b
commit
a42b94ab72
|
@ -125,7 +125,7 @@ class GptTtsHf(nn.Module):
|
|||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
|
||||
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1):
|
||||
def inference(self, text_inputs, cond_input, **hf_generate_kwargs):
|
||||
if not hasattr(self, 'inference_model'):
|
||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
|
||||
|
||||
|
@ -133,28 +133,21 @@ class GptTtsHf(nn.Module):
|
|||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.text_embedding(text_inputs)
|
||||
|
||||
# Format conditioning inputs properly.
|
||||
if len(cond_inputs.shape) == 3:
|
||||
cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1}
|
||||
if cond_inputs.shape[-1] > self.max_conditioning_length:
|
||||
cond_inputs = cond_inputs[:,:,:,:self.max_conditioning_length]
|
||||
# Randomly permute the conditioning spectrogram, to destroy any structure present.
|
||||
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
|
||||
if cond_input.shape[-1] > self.max_conditioning_length:
|
||||
cond_input = cond_input[:,:,:self.max_conditioning_length]
|
||||
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
|
||||
|
||||
conds = []
|
||||
for k in range(cond_inputs.shape[1]):
|
||||
conds.append(self.conditioning_encoder(cond_inputs[:, k]))
|
||||
while len(conds) < self.max_conditioning_inputs:
|
||||
conds.append(conds[-1])
|
||||
conds = torch.stack(conds, dim=1)
|
||||
|
||||
emb = torch.cat([text_emb, conds], dim=1)
|
||||
emb = torch.cat([text_emb, cond], dim=1)
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True, repetition_penalty=repetition_penalty)
|
||||
return gen[:, fake_inputs.shape[1]:-1]
|
||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||
max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
|
||||
return gen[:, fake_inputs.shape[1]:]
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -43,8 +43,7 @@ if __name__ == '__main__':
|
|||
cond = inp if args.cond is None else load_audio(args.cond, 22050)
|
||||
if cond.shape[-1] > 44100+10000:
|
||||
cond = cond[:,10000:54100]
|
||||
cond = torchaudio.transforms.Resample(22050, 10025)(cond.cpu()).cuda()
|
||||
|
||||
print("Performing inference..")
|
||||
roundtripped = roundtrip_vocoding(dvae, diffusion, diffuser, inp, cond).cpu()
|
||||
torchaudio.save('roundtrip_vocoded_output.wav', roundtripped.squeeze(0), 10025)
|
||||
torchaudio.save('roundtrip_vocoded_output.wav', roundtripped.squeeze(0), 11025)
|
|
@ -16,10 +16,6 @@ from utils.options import Loader
|
|||
from utils.util import load_model_from_config
|
||||
|
||||
|
||||
def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
    """Placeholder entry point: performs no work and returns None.

    All arguments are accepted and ignored.
    """
    return None
|
||||
|
||||
|
||||
# Loads multiple conditioning files at random from a folder.
|
||||
def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100):
|
||||
candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0]
|
||||
|
@ -50,7 +46,38 @@ def load_conditioning(path, sample_rate=22050, cond_length=44100):
|
|||
return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
|
||||
|
||||
|
||||
def fix_autoregressive_output(codes, stop_token,
                              padding_codes=(83, 83, 83, 83, 83, 83, 83, 83, 83, 45, 45, 248)):
    """
    Truncate autoregressively generated codes at the first stop token and append end padding.

    This performs some padding on coded audio that fixes a mismatch between what the diffusion model was
    trained on and what the autoregressive code generator creates (which has no padding or end).

    The padding is highly specific to the DVAE being used, so the default `padding_codes` will not
    necessarily work with a different DVAE. Suitable values can be inferred by feeding an audio clip padded
    with lots of zeros on the end through the DVAE and copying out the last few codes.

    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.

    Args:
        codes: 1-D tensor of generated codes for a single sample.
        stop_token: Code value that marks the end of generation.
        padding_codes: Sequence of integer codes appended after truncation. Defaults to the values
            hard-coded for the DVAE this script was written against.

    Returns:
        A 1-D tensor: `codes` up to (not including) the first `stop_token`, followed by the padding.
        If no stop token is present, all of `codes` is kept and a notice is printed.

    Raises:
        ValueError: If `codes` is not 1-dimensional.
    """
    if codes.dim() != 1:
        # Batched input previously failed later with an obscure indexing/concatenation error.
        raise ValueError(f"codes must be a 1-D tensor (one sample at a time), got shape {tuple(codes.shape)}")

    # Strip off the autoregressive stop token (and anything after it) and add padding.
    stop_token_indices = (codes == stop_token).nonzero()
    if len(stop_token_indices) == 0:
        print("No stop tokens found, enjoy that output of yours!")
    else:
        # nonzero() returns an (N, 1) tensor for 1-D input; convert the first hit to a plain int
        # rather than relying on implicit tensor-to-index conversion in the slice.
        first_stop = int(stop_token_indices[0, 0])
        codes = codes[:first_stop]

    padding = torch.tensor(list(padding_codes), dtype=torch.long, device=codes.device)
    return torch.cat([codes, padding])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Preset conditioning voices: maps a short preset name to the path of a reference audio clip.
# Every backslash in these Windows paths must be doubled; a lone "\5" would be parsed as an
# octal escape (U+0005) and silently corrupt the path.
preselected_cond_voices = {
    'trump': 'D:\\data\\audio\\sample_voices\\trump.wav',
    'ryan_reynolds': 'D:\\data\\audio\\sample_voices\\ryan_reynolds.wav',
    'ed_sheeran': 'D:\\data\\audio\\sample_voices\\ed_sheeran.wav',
    'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
    'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav',
    'dan_carlin': 'Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav',
}
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt_diffuse', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae.yml')
|
||||
parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
|
||||
|
@ -58,9 +85,11 @@ if __name__ == '__main__':
|
|||
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
|
||||
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
|
||||
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
|
||||
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\28500_gpt_ema.pth')
|
||||
parser.add_argument('-text', type=str, help='Text to speak.', default="Please set this in the courier drone when we dock.")
|
||||
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav')
|
||||
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\50000_gpt.pth')
|
||||
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
||||
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
|
||||
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='simmons')
|
||||
parser.add_argument('-num_samples', type=int, help='How many outputs to produce.', default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading GPT TTS..")
|
||||
|
@ -71,12 +100,15 @@ if __name__ == '__main__':
|
|||
|
||||
print("Loading data..")
|
||||
text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
|
||||
conds, cond_wav = load_conditioning(args.cond_path)
|
||||
cond_path = args.cond_path if args.cond_preset is None else preselected_cond_voices[args.cond_preset]
|
||||
conds, cond_wav = load_conditioning(cond_path)
|
||||
|
||||
print("Performing GPT inference..")
|
||||
codes = gpt.inference(text, conds, num_beams=32, repetition_penalty=10.0)
|
||||
codes = gpt.inference(text, conds, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=20, top_p=.95,
|
||||
num_return_sequences=args.num_samples, length_penalty=.1, early_stopping=True)
|
||||
|
||||
# Delete the GPT TTS model to free up GPU memory
|
||||
stop_token = gpt.STOP_MEL_TOKEN
|
||||
del gpt
|
||||
|
||||
print("Loading DVAE..")
|
||||
|
@ -86,5 +118,9 @@ if __name__ == '__main__':
|
|||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=50)
|
||||
|
||||
print("Performing vocoding..")
|
||||
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, cond_wav, spectrogram_compression_factor=128, plt_spec=False)
|
||||
torchaudio.save('gpt_tts_output.wav', wav.squeeze(0).cpu(), 10025)
|
||||
# Perform vocoding on each batch element separately: Vocoding is very memory intensive.
|
||||
for b in range(codes.shape[0]):
|
||||
code = fix_autoregressive_output(codes[b], stop_token).unsqueeze(0)
|
||||
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav,
|
||||
spectrogram_compression_factor=128, plt_spec=False)
|
||||
torchaudio.save(f'gpt_tts_output_{b}.wav', wav.squeeze(0).cpu(), 11025)
|
Loading…
Reference in New Issue
Block a user