forked from mrq/DL-Art-School
go all in on m2wv3
This commit is contained in:
parent
24a78bd7d1
commit
13c263e9fb
|
@ -78,7 +78,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
|
||||
conditioning_free=False, conditioning_free_k=1)
|
||||
self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent.
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
elif 'from_ar_prior' == mode:
|
||||
self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
|
||||
|
@ -89,15 +88,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
conditioning_free=True, conditioning_free_k=1)
|
||||
self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {})
|
||||
self.local_modules['ar_prior'] = get_ar_prior()
|
||||
self.spec_decoder = get_mel2wav_v3_model()
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
elif 'chained_sr' == mode:
|
||||
self.diffusion_fn = self.perform_chained_sr
|
||||
self.spec_decoder = get_mel2wav_v3_model()
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
if not hasattr(self, 'spec_decoder'):
|
||||
self.spec_decoder = get_mel2wav_model()
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
self.spec_decoder = get_mel2wav_v3_model()
|
||||
self.local_modules['spec_decoder'] = self.spec_decoder
|
||||
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
|
||||
'normalize': True, 'in': 'in', 'out': 'out'}, {})
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user