From 1e3a8554a131824fc4c0aaa154e47bb0cc300ef6 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 15 Mar 2022 11:35:09 -0600
Subject: [PATCH] updates to audio_diffusion_fid

---
 .../audio/gen/speech_synthesis_utils.py   | 24 +++++++-
 codes/trainer/eval/audio_diffusion_fid.py | 58 +++++++++++--------
 2 files changed, 55 insertions(+), 27 deletions(-)

diff --git a/codes/scripts/audio/gen/speech_synthesis_utils.py b/codes/scripts/audio/gen/speech_synthesis_utils.py
index 45fc0079..4a3c6158 100644
--- a/codes/scripts/audio/gen/speech_synthesis_utils.py
+++ b/codes/scripts/audio/gen/speech_synthesis_utils.py
@@ -5,16 +5,28 @@ import torch
 
 from data.audio.unsupervised_audio_dataset import load_audio
 from data.util import find_files_of_type, is_audio_file
+from models.audio.vocoders.univnet.generator import UnivNetGenerator
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.respace import SpacedDiffusion, space_timesteps
-from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, MelSpectrogramInjector
 from utils.audio import plot_spectrogram
 from utils.util import load_model_from_config
 
 
 def load_speech_dvae():
-    return load_model_from_config("../experiments/train_diffusion_vocoder_22k_level.yml",
+    dvae = load_model_from_config("../experiments/train_diffusion_vocoder_22k_level.yml",
                                   "dvae").cpu()
+    dvae.eval()
+    return dvae
+
+
+def load_univnet_vocoder():
+    model = UnivNetGenerator()
+    sd = torch.load('univnet_c32_pretrained_libri.pt')
+    model.load_state_dict(sd)
+    model = model.cpu()
+    model.eval()
+    return model
 
 
 def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
@@ -24,6 +36,14 @@ def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
     return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': mel_norms_file},{})({'wav': wav})['mel']
 
 
+def wav_to_univnet_mel(wav):
+    """
+    Converts an audio clip into a MEL tensor that the univnet vocoder knows how to decode.
+    """
+    return MelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'sampling_rate': 24000,
+                                   'n_mel_channels': 100, 'mel_fmax': 12000},{})({'wav': wav})['mel']
+
+
 def convert_mel_to_codes(dvae_model, mel):
     """
     Converts an audio clip into discrete codes.
diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py
index 2a663124..294656f8 100644
--- a/codes/trainer/eval/audio_diffusion_fid.py
+++ b/codes/trainer/eval/audio_diffusion_fid.py
@@ -15,7 +15,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
 from models.clip.mel_text_clip import MelTextCLIP
 from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
-    convert_mel_to_codes
+    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
 from utils.util import ceil_multiple, opt_get
 
 
@@ -43,16 +43,19 @@ class AudioDiffusionFid(evaluator.Evaluator):
                                                        conditioning_free_k=conditioning_free_k)
         self.dev = self.env['device']
         mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
+        self.local_modules = {}
         if mode == 'tts':
             self.diffusion_fn = self.perform_diffusion_tts
         elif mode == 'original_vocoder':
-            self.dvae = load_speech_dvae().to(self.env['device'])
-            self.dvae.eval()
+            self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.diffusion_fn = self.perform_original_diffusion_vocoder
         elif mode == 'vocoder':
-            self.dvae = load_speech_dvae().to(self.env['device'])
-            self.dvae.eval()
+            self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.diffusion_fn = self.perform_diffusion_vocoder
+        elif mode == 'tts9_mel':
+            self.local_modules['dvae'] = load_speech_dvae().cpu()
+            self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
+            self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
 
     def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
         real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
@@ -71,8 +74,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
 
     def perform_original_diffusion_vocoder(self, audio, codes, text, sample_rate=11025):
         mel = wav_to_mel(audio)
-        mel_codes = convert_mel_to_codes(self.dvae, mel)
-        back_to_mel = self.dvae.decode(mel_codes)[0]
+        mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
+        back_to_mel = self.local_modules['dvae'].decode(mel_codes)[0]
         orig_audio = audio
         real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
 
@@ -96,7 +99,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
 
     def perform_diffusion_vocoder(self, audio, codes, text, sample_rate=5500):
         mel = wav_to_mel(audio)
-        mel_codes = convert_mel_to_codes(self.dvae, mel)
+        mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         text_codes = text_to_sequence(text)
         real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
 
@@ -115,11 +118,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
 
         return gen, real_resampled, sample_rate
 
-    def perform_diffusion_tts9_from_codes(self, audio, codes, text, sample_rate=5500):
+    def perform_diffusion_tts9_mel_from_codes(self, audio, codes, text):
+        SAMPLE_RATE = 24000
         mel = wav_to_mel(audio)
-        mel_codes = convert_mel_to_codes(self.dvae, mel)
-        text_codes = text_to_sequence(text)
-        real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
+        mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
+        real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
+        univnet_mel = wav_to_univnet_mel(audio)  # to be used for a conditioning input
 
         output_size = real_resampled.shape[-1]
         aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
@@ -129,11 +133,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
         if padding_needed_for_codes > 0:
             mel_codes = F.pad(mel_codes, (0, padding_needed_for_codes))
         output_shape = (1, 1, padded_size)
-        gen = self.diffuser.p_sample_loop(self.model, output_shape,
-                                          model_kwargs={'tokens': mel_codes,
-                                                        'conditioning_input': audio.unsqueeze(0),
-                                                        'unaligned_input': torch.tensor(text_codes, device=audio.device).unsqueeze(0)})
-        return gen, real_resampled, sample_rate
+        gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
+                                              model_kwargs={'aligned_conditioning': mel_codes,
+                                                            'conditioning_input': univnet_mel})
+
+        gen_wav = self.local_modules['vocoder'](gen_mel)
+        return gen_wav, real_resampled, SAMPLE_RATE
 
     def load_projector(self):
         """
@@ -187,11 +192,11 @@ class AudioDiffusionFid(evaluator.Evaluator):
         projector = self.load_projector().to(self.env['device'])
         projector.eval()
 
-        if hasattr(self, 'dvae'):
-            self.dvae = self.dvae.to(self.env['device'])
 
         w2v = self.load_w2v().to(self.env['device'])
         w2v.eval()
+        for k, mod in self.local_modules.items():
+            self.local_modules[k] = mod.to(self.env['device'])
 
         # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
         rng_state = torch.get_rng_state()
@@ -226,10 +231,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
                 intelligibility_loss = intelligibility_loss / distributed.get_world_size()
 
         self.model.train()
-        if hasattr(self, 'dvae'):
-            self.dvae = self.dvae.to('cpu')
         torch.set_rng_state(rng_state)
 
+        # Put modules used for evaluation back into CPU memory.
+        for k, mod in self.local_modules.items():
+            self.local_modules[k] = mod.cpu()
+
        return {"frechet_distance": frechet_distance, "intelligibility_loss": intelligibility_loss}
 
 """
@@ -250,11 +257,12 @@ if __name__ == '__main__':
 
 if __name__ == '__main__':
     from utils.util import load_model_from_config
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_vocoder_clips_from_dvae_archived_r3_b256_conditioning\\config.yml', 'generator',
-                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_vocoder_clips_from_dvae_archived_r3_b256_conditioning\\models\\80800_generator_ema.pth').cuda()
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9.yml', 'generator',
+                                       also_load_savepoint=False,
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9\\models\\7500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
                 'conditioning_free': False, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'original_vocoder'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 4, 'device': 'cuda', 'opt': {}}
+                'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
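
Note (not part of the applied diff): a minimal round-trip sketch of the two univnet helpers this
patch adds, load_univnet_vocoder and wav_to_univnet_mel, mirroring how
perform_diffusion_tts9_mel_from_codes invokes the vocoder on a MEL. The input/output paths are
placeholders and the tensor shapes are assumptions inferred from the evaluator code above, so
treat this as an untested illustration rather than part of the patch:

    # Round-trip sketch (untested): encode a 24kHz clip to a univnet-format MEL,
    # then vocode it back to a waveform. 'example.wav' is a placeholder path.
    import torch
    import torchaudio

    from data.audio.unsupervised_audio_dataset import load_audio
    from scripts.audio.gen.speech_synthesis_utils import load_univnet_vocoder, wav_to_univnet_mel

    wav = load_audio('example.wav', 24000)   # assumed (1, N) float tensor at 24kHz
    mel = wav_to_univnet_mel(wav)            # MEL with 100 channels, fmax=12kHz, per the injector config

    vocoder = load_univnet_vocoder()
    with torch.no_grad():
        # The evaluator calls the vocoder directly on the MEL, as in
        # gen_wav = self.local_modules['vocoder'](gen_mel) above.
        reconstructed = vocoder(mel)

    # Shape assumed to follow the (1, 1, N) waveform convention used in the
    # evaluator; squeeze to (channels, samples) for torchaudio.save.
    torchaudio.save('univnet_roundtrip.wav', reconstructed.squeeze(0).cpu(), 24000)

If the reconstruction sounds close to the input, the MelSpectrogramInjector parameters
(24kHz sampling rate, 100 mel channels, 12kHz fmax) match what the pretrained univnet
checkpoint expects, which is the contract the new tts9_mel evaluation path relies on.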