go all in on m2wv3

pull/2/head
James Betker 2022-07-21 00:51:27 +07:00
parent 24a78bd7d1
commit 13c263e9fb
1 changed files with 2 additions and 6 deletions

@ -78,7 +78,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
conditioning_free=False, conditioning_free_k=1)
self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent.
self.local_modules['spec_decoder'] = self.spec_decoder
elif 'from_ar_prior' == mode:
self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
@ -89,15 +88,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
conditioning_free=True, conditioning_free_k=1)
self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {})
self.local_modules['ar_prior'] = get_ar_prior()
self.spec_decoder = get_mel2wav_v3_model()
self.local_modules['spec_decoder'] = self.spec_decoder
elif 'chained_sr' == mode:
self.diffusion_fn = self.perform_chained_sr
self.spec_decoder = get_mel2wav_v3_model()
self.local_modules['spec_decoder'] = self.spec_decoder
if not hasattr(self, 'spec_decoder'):
self.spec_decoder = get_mel2wav_model()
self.local_modules['spec_decoder'] = self.spec_decoder
self.spec_decoder = get_mel2wav_v3_model()
self.local_modules['spec_decoder'] = self.spec_decoder
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
'normalize': True, 'in': 'in', 'out': 'out'}, {})