From 34d9d5f202d8101a7014e35b4073e7f19d743565 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 15 Jun 2022 16:41:08 -0600 Subject: [PATCH] adf for ar-latent tfd --- codes/trainer/eval/audio_diffusion_fid.py | 38 +++++++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 91d6eab7..cb81a49d 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -91,6 +91,15 @@ class AudioDiffusionFid(evaluator.Evaluator): elif 'tfd' == mode: self.diffusion_fn = self.perform_diffusion_tfd self.local_modules['vocoder'] = load_univnet_vocoder().cpu() + elif 'tfd_ar' == mode: + self.local_modules['dvae'] = load_speech_dvae().cpu() + self.local_modules['autoregressive'] = load_model_from_config("../experiments/train_gpt_tts_unified.yml", + model_name='gpt', + also_load_savepoint=False, + load_path='../experiments/tortoise_ar.pth', + device=torch.device('cpu')).cuda().eval() + self.diffusion_fn = self.perform_diffusion_tfd_ar_prior + self.local_modules['vocoder'] = load_univnet_vocoder().cpu() def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500): real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) @@ -219,6 +228,29 @@ class AudioDiffusionFid(evaluator.Evaluator): real_dec = self.local_modules['vocoder'].inference(denormalize_mel(umel)) return gen_wav.float(), real_dec, gen_mel, umel, SAMPLE_RATE + def perform_diffusion_tfd_ar_prior(self, audio, codes, text): + SAMPLE_RATE = 24000 + audio_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0) + vmel = wav_to_mel(audio) + umel = wav_to_univnet_mel(audio_resampled, do_normalization=True) + + mel_codes = convert_mel_to_codes(self.local_modules['dvae'], vmel) + text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(vmel.device) + cond_inputs = pad_or_truncate(vmel, 132300//256).unsqueeze(1) + mlc = self.local_modules['autoregressive'].mel_length_compression + auto_latents = self.local_modules['autoregressive'](cond_inputs, text_codes, + torch.tensor([text_codes.shape[-1]], device=vmel.device), + mel_codes, + torch.tensor([mel_codes.shape[-1]*mlc], device=vmel.device), + text_first=True, raw_mels=None, return_latent=True) + + gen_mel = self.diffuser.p_sample_loop(self.model, umel.shape, + model_kwargs={'codes': auto_latents}) + + gen_wav = self.local_modules['vocoder'].inference(denormalize_mel(gen_mel)) + real_dec = self.local_modules['vocoder'].inference(denormalize_mel(umel)) + return gen_wav.float(), real_dec, gen_mel, umel, SAMPLE_RATE + def load_projector(self): """ Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading @@ -335,12 +367,12 @@ if __name__ == '__main__': if __name__ == '__main__': - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_tts_diffusion_tfd11_quant.yml', 'generator', + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_tts_diffusion_tfd12_ar_inputs.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_tts_diffusion_tfd12_linear_dvae\\models\\12000_generator.pth').cuda() + load_path='X:\\dlas\\experiments\\train_tts_diffusion_tfd12_ar_inputs_pretrain\\models\\4500_generator.pth').cuda() opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50, 'conditioning_free': False, 'conditioning_free_k': 1, - 'diffusion_schedule': 'linear', 'diffusion_type': 'tfd'} + 'diffusion_schedule': 'linear', 'diffusion_type': 'tfd_ar'} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 101, 'device': 'cuda', 'opt': {}} eval = AudioDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval())