From 13c263e9fb57ac1a94e7c16d0e3ef05e7d121585 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 21 Jul 2022 00:51:27 -0600 Subject: [PATCH] go all in on m2wv3 --- codes/trainer/eval/music_diffusion_fid.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 3fda8f72..6150f893 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -78,7 +78,6 @@ class MusicDiffusionFid(evaluator.Evaluator): self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), conditioning_free=False, conditioning_free_k=1) - self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent. self.local_modules['spec_decoder'] = self.spec_decoder elif 'from_ar_prior' == mode: self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior @@ -89,15 +88,12 @@ class MusicDiffusionFid(evaluator.Evaluator): conditioning_free=True, conditioning_free_k=1) self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {}) self.local_modules['ar_prior'] = get_ar_prior() - self.spec_decoder = get_mel2wav_v3_model() self.local_modules['spec_decoder'] = self.spec_decoder elif 'chained_sr' == mode: self.diffusion_fn = self.perform_chained_sr - self.spec_decoder = get_mel2wav_v3_model() - self.local_modules['spec_decoder'] = self.spec_decoder - if not hasattr(self, 'spec_decoder'): - self.spec_decoder = get_mel2wav_model() self.local_modules['spec_decoder'] = self.spec_decoder + self.spec_decoder = get_mel2wav_v3_model() + self.local_modules['spec_decoder'] = self.spec_decoder self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})