From ba1699cee2ac4bb543bcc37d78ed369b6ae7f614 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 8 Jul 2022 12:30:22 -0600 Subject: [PATCH] Improve mdf --- codes/trainer/eval/music_diffusion_fid.py | 140 ++++++++++++++++++---- 1 file changed, 120 insertions(+), 20 deletions(-) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 184550ef..9d52d177 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -84,6 +84,15 @@ class MusicDiffusionFid(evaluator.Evaluator): conditioning_free=True, conditioning_free_k=1) self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent. self.local_modules['spec_decoder'] = self.spec_decoder + elif 'cheater_gen_fake_ar' == mode: + self.diffusion_fn = self.perform_fake_ar_reconstruction_from_cheater_gen + self.local_modules['cheater_encoder'] = get_cheater_encoder() + self.local_modules['cheater_decoder'] = get_cheater_decoder() + self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), + conditioning_free=True, conditioning_free_k=1) + self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent. + self.local_modules['spec_decoder'] = self.spec_decoder elif 'from_ar_prior' == mode: self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior self.local_modules['cheater_encoder'] = get_cheater_encoder() @@ -223,21 +232,29 @@ class MusicDiffusionFid(evaluator.Evaluator): # 1. Generate the cheater latent using the input as a reference. gen_cheater = self.diffuser.ddim_sample_loop(self.model, cheater.shape, progress=True, model_kwargs={'conditioning_input': cheater}) - # 2. Decode the cheater into a MEL - gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True, - model_kwargs={'codes': gen_cheater.permute(0,2,1)}) + # 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory. + chunks = torch.split(gen_cheater, 64, dim=-1) + gen_wavs = [] + for chunk in tqdm(chunks): + gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,chunk.shape[-1]*16), progress=True, + model_kwargs={'codes': chunk.permute(0,2,1)}) - # 3. And then the MEL back into a spectrogram - output_shape = (1,16,audio.shape[-1]//16) - self.spec_decoder = self.spec_decoder.to(audio.device) - gen_mel_denorm = denormalize_mel(gen_mel) - gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, - model_kwargs={'codes': gen_mel_denorm}) - gen_wav = pixel_shuffle_1d(gen_wav, 16) + # 3. And then the MEL back into a spectrogram + output_shape = (1,16,audio.shape[-1]//(16*len(chunks))) + self.spec_decoder = self.spec_decoder.to(audio.device) + gen_mel_denorm = denormalize_mel(gen_mel) + gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'codes': gen_mel_denorm}) + gen_wav = pixel_shuffle_1d(gen_wav, 16) + gen_wavs.append(gen_wav) + gen_wav = torch.cat(gen_wavs, dim=-1) - real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, - model_kwargs={'codes': mel}) - real_wav = pixel_shuffle_1d(real_wav, 16) + if audio.shape[-1] < 40 * 22050: + real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'codes': mel}) + real_wav = pixel_shuffle_1d(real_wav, 16) + else: + real_wav = audio # TODO: chunk like above. return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate @@ -266,6 +283,87 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate + def perform_fake_ar_reconstruction_from_cheater_gen(self, audio, sample_rate=22050): + assert self.ddim, "DDIM mode expected for reconstructing cheater gen. Do you like to waste resources??" + audio = audio.unsqueeze(0) + + mel = self.spec_fn({'in': audio})['out'] + mel_norm = normalize_mel(mel) + cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm) + + # 1. Generate the cheater latent using the input as a reference. + def diffuse(i, ref): + mask = torch.zeros_like(ref) + mask[:,:,:i] = 1 + return self.diffuser.p_sample_loop_with_guidance(self.model, ref, mask, model_kwargs={'conditioning_input': cheater}) + gen_cheater = torch.randn_like(cheater) + for i in range(cheater.shape[-1]): + gen_cheater = diffuse(i, gen_cheater) + if i > 128: + # abort early. + gen_cheater = gen_cheater[:,:,:128] + break + + # 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory. + chunks = torch.split(gen_cheater, 64, dim=-1) + gen_wavs = [] + for chunk in tqdm(chunks): + gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,chunk.shape[-1]*16), progress=True, + model_kwargs={'codes': chunk.permute(0,2,1)}) + + # 3. And then the MEL back into a spectrogram + output_shape = (1,16,audio.shape[-1]//(16*len(chunks))) + self.spec_decoder = self.spec_decoder.to(audio.device) + gen_mel_denorm = denormalize_mel(gen_mel) + gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'codes': gen_mel_denorm}) + gen_wav = pixel_shuffle_1d(gen_wav, 16) + gen_wavs.append(gen_wav) + gen_wav = torch.cat(gen_wavs, dim=-1) + + """ How to do progressive, causal decoding of the TFD diffuser: + MAX_CONTEXT = 64 + def diffuse(start, len, guidance): + mask = torch.zeros_like(guidance) + mask[:,:,:(len-start)] = 1 + return self.cheater_decoder_diffuser.p_sample_loop_with_guidance(self.local_modules['cheater_decoder'].diff.to(audio.device), + guidance_input=guidance, mask=mask, + model_kwargs={'codes': gen_cheater[:,:,start:start+MAX_CONTEXT].permute(0,2,1)}) + guidance_mel = torch.zeros((1,256,MAX_CONTEXT*16), device=mel.device) + gen_mel = torch.zeros((1,256,0), device=mel.device) + for i in tqdm(list(range(gen_cheater.shape[-1]))): + start = max(0, i-MAX_CONTEXT-1) + l = min(16*(MAX_CONTEXT-1), i*16) + ngm = diffuse(start, l, guidance_mel) + gen_mel = torch.cat([gen_mel, ngm[:,:,l:l+16]], dim=-1) + if gen_mel.shape[-1] < guidance_mel.shape[-1]: + guidance_mel[:,:,:gen_mel.shape[-1]] = gen_mel + else: + guidance_mel = gen_mel[:,:,-guidance_mel.shape[-1]:] + + chunks = torch.split(gen_mel, MAX_CONTEXT*16, dim=-1) + gen_wavs = [] + for chunk_mel in tqdm(chunks): + # 3. And then the MEL back into a spectrogram + output_shape = (1,16,audio.shape[-1]//(16*len(chunks))) + self.spec_decoder = self.spec_decoder.to(audio.device) + gen_mel_denorm = denormalize_mel(chunk_mel) + gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'codes': gen_mel_denorm}) + gen_wav = pixel_shuffle_1d(gen_wav, 16) + gen_wavs.append(gen_wav) + gen_wav = torch.cat(gen_wavs, dim=-1) + """ + + if audio.shape[-1] < 40 * 22050: + real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'codes': mel}) + real_wav = pixel_shuffle_1d(real_wav, 16) + else: + real_wav = audio # TODO: chunk like above. + + return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate + def project(self, sample, sample_rate): sample = torchaudio.functional.resample(sample, sample_rate, 22050) mel = self.spec_fn({'in': sample})['out'] @@ -302,7 +400,9 @@ class MusicDiffusionFid(evaluator.Evaluator): real_projections = [] for i in tqdm(list(range(0, len(self.data), self.skip))): path = self.data[(i + self.env['rank']) % len(self.data)] - audio = load_audio(path, 22050).to(self.dev) + #audio = load_audio(path, 22050).to(self.dev) + audio = load_audio('C:\\Users\\James\\Music\\another_longer_sample.wav', 22050).to(self.dev) # <- hack, remove it! + audio = audio[:, :1764000] if self.clip: audio = audio[:, :100000] sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio) @@ -334,17 +434,17 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_tfd12_finetune_ar_outputs.yml', 'generator', + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen_r8.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd12_finetune_from_cheater_ar\\models\\7500_generator.pth' + load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5\\models\\203000_generator_ema.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. - 'diffusion_steps': 32, - 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, # 'clip_audio': False, - 'diffusion_schedule': 'linear', 'diffusion_type': 'from_ar_prior', + 'diffusion_steps': 64, + 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False, + 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen_fake_ar', #'partial_low': 128, 'partial_high': 192 } - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 230, 'device': 'cuda', 'opt': {}} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 232, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval())