diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 27eb0781..7d8fd92a 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -232,7 +232,7 @@ class TransformerDiffusion(nn.Module):
         assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'

         if conditioning_free:
-            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1)
         else:
             MIN_COND_LEN = 200
             MAX_COND_LEN = 1200
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 6150f893..0ee8fe61 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -19,7 +19,7 @@ from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
     KmeansQuantizerInjector, normalize_torch_mel
-from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
+from utils.music_utils import get_music_codegen, get_cheater_decoder, get_cheater_encoder, \
     get_mel2wav_v3_model, get_ar_prior
 from utils.util import opt_get, load_model_from_config

@@ -58,8 +58,9 @@ class MusicDiffusionFid(evaluator.Evaluator):
         self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
         self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))
+        self.spec_decoder = get_mel2wav_v3_model()

-        self.local_modules = {'projector': self.projector}
+        self.local_modules = {'projector': self.projector, 'spec_decoder': self.spec_decoder}

         if mode == 'spec_decode':
             self.diffusion_fn = self.perform_diffusion_spec_decode
             self.squeeze_ratio = opt_eval['squeeze_ratio']
@@ -78,7 +79,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
             self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon',
                                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                            conditioning_free=False, conditioning_free_k=1)
-            self.local_modules['spec_decoder'] = self.spec_decoder
         elif 'from_ar_prior' == mode:
             self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
             self.local_modules['cheater_encoder'] = get_cheater_encoder()
@@ -88,12 +88,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                            conditioning_free=True, conditioning_free_k=1)
             self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {})
             self.local_modules['ar_prior'] = get_ar_prior()
-            self.local_modules['spec_decoder'] = self.spec_decoder
         elif 'chained_sr' == mode:
             self.diffusion_fn = self.perform_chained_sr
-            self.local_modules['spec_decoder'] = self.spec_decoder

-        self.spec_decoder = get_mel2wav_v3_model()
-        self.local_modules['spec_decoder'] = self.spec_decoder
         self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'normalize': True,
                                                     'in': 'in', 'out': 'out'}, {})

@@ -149,11 +145,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
         self.spec_decoder = self.spec_decoder.to(audio.device)
         sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
         gen_wav = sampler(self.spec_decoder, output_shape,
-                          model_kwargs={'aligned_conditioning': gen_mel_denorm})
+                          model_kwargs={'codes': gen_mel_denorm})
         gen_wav = pixel_shuffle_1d(gen_wav, 16)

         real_wav = sampler(self.spec_decoder, output_shape,
-                           model_kwargs={'aligned_conditioning': mel})
+                           model_kwargs={'codes': mel})
         real_wav = pixel_shuffle_1d(real_wav, 16)

         return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
@@ -227,15 +223,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
         mel = self.spec_fn({'in': audio})['out']
         mel_norm = normalize_torch_mel(mel)
         conditioning = mel_norm[:,:,:1200]
-        downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
+        downsampled = F.interpolate(mel_norm, scale_factor=1/4, mode='nearest')
         stage1_shape = (1, 256, downsampled.shape[-1]*4)
         sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
-        # Chain super-sampling using 2 stages.
-        stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
-                                                                 'x_prior': downsampled,
-                                                                 'conditioning_input': conditioning})
+        # (Eventually) Chain super-sampling using 2 stages.
+        #stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
+        #                                                         'x_prior': downsampled,
+        #                                                         'conditioning_input': conditioning})
         stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
-                                                              'x_prior': stage1,
+                                                              'x_prior': downsampled,
                                                               'conditioning_input': conditioning})
         # Decode into waveform.
         output_shape = (1,16,audio.shape[-1]//16)
@@ -318,24 +314,24 @@ class MusicDiffusionFid(evaluator.Evaluator):


 if __name__ == '__main__':
-    """
     # For multilevel SR:
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
                                        also_load_savepoint=False, strict_load=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\6000_generator.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                 'diffusion_steps': 128,  # basis: 192
-                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
+                'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
                 'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
                 }
+    """

     # For TFD+cheater trainer
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
                                        also_load_savepoint=False, strict_load=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\1000_generator.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
@@ -343,9 +339,9 @@ if __name__ == '__main__':
                 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
                 }
+    """
-
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 7, 'device': 'cuda', 'opt': {}}

     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):
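
Note on the 'aligned_conditioning' -> 'codes' rename above: in guided-diffusion-style samplers such as the ddim_sample_loop / p_sample_loop calls in this diff, model_kwargs are generally forwarded verbatim into the denoiser's forward call, so the dictionary key has to match the argument name of the mel2wav v3 decoder. The sketch below is a minimal, self-contained illustration of that contract under that assumption; ToyMel2Wav and toy_sample_loop are hypothetical stand-ins, not classes or functions from this repo.

# Minimal sketch: the model_kwargs key must track the denoiser's forward signature.
import torch
import torch.nn as nn


class ToyMel2Wav(nn.Module):
    """Toy denoiser whose forward expects its conditioning under the 'codes' keyword."""
    def __init__(self, channels=16):
        super().__init__()
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x, t, codes=None):
        # A real model would consume t and codes; here we only check the kwarg arrives.
        assert codes is not None, "conditioning must arrive under the 'codes' key"
        return self.proj(x)


def toy_sample_loop(model, shape, model_kwargs):
    # Mirrors how p_sample_loop / ddim_sample_loop pass model_kwargs through to the model.
    x = torch.randn(shape)
    t = torch.zeros(shape[0], dtype=torch.long)
    return model(x, t, **model_kwargs)


decoder = ToyMel2Wav()
wav = toy_sample_loop(decoder, (1, 16, 4096), model_kwargs={'codes': torch.randn(1, 256, 256)})
print(wav.shape)  # torch.Size([1, 16, 4096])

Under the same assumption, keeping the old 'aligned_conditioning' key would surface as a TypeError (unexpected keyword argument) inside the sample loop, which is what the rename in the diff avoids.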