From 368dca18b14fec8a33882b5992da4c787888313e Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 19 Jun 2022 15:07:24 -0600 Subject: [PATCH] mdf fixes + support for tfd-based waveform gen --- codes/trainer/eval/music_diffusion_fid.py | 43 ++++++++--------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index b41f7572..066b2e28 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -79,24 +79,18 @@ class MusicDiffusionFid(evaluator.Evaluator): return list(glob(f'{path}/*.wav')) def perform_diffusion_spec_decode(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) - output_shape = (1, 16, audio.shape[-1] // 16) + output_shape = (1, 256, audio.shape[-1] // 256) mel = self.spec_fn({'in': audio})['out'] gen = self.diffuser.p_sample_loop(self.model, output_shape, - model_kwargs={'aligned_conditioning': mel}) - gen = pixel_shuffle_1d(gen, 16) + model_kwargs={'codes': mel}) + gen = pixel_shuffle_1d(gen, 256) return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate def perform_diffusion_from_codes(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -116,10 +110,7 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -148,10 +139,7 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -174,10 +162,7 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -273,17 +258,17 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator', + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41500_generator_ema.pth' + load_path='X:\\dlas\\experiments\\train_music_waveform_gen_retry\\models\\22000_generator_ema.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. - 'diffusion_steps': 200, - 'conditioning_free': True, 'conditioning_free_k': 2, - 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant', + 'diffusion_steps': 100, + 'conditioning_free': False, 'conditioning_free_k': 1, + 'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode', #'partial_low': 128, 'partial_high': 192 } - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 605, 'device': 'cuda', 'opt': {}} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 100, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval())