diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index cd9dc208..35c4a10a 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -336,7 +336,7 @@ class GaussianDiffusion: if self.conditioning_free: if self.ramp_conditioning_free: assert t.shape[0] == 1 # This should only be used in inference. - cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) + cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t).float().mean().item() / self.num_timesteps) else: cfk = self.conditioning_free_k model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning @@ -660,9 +660,12 @@ class GaussianDiffusion: out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + if len(t.shape) == 2: + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + else: + nonzero_mask = (t != 0).float() sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} @@ -710,11 +713,13 @@ class GaussianDiffusion: shape, noise=None, clip_denoised=True, + causal=False, + causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, - progress=False, + progress=True, eta=0.0, ): """ @@ -728,6 +733,8 @@ class GaussianDiffusion: shape, noise=noise, clip_denoised=clip_denoised, + causal=causal, + causal_slope=causal_slope, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, @@ -744,11 +751,13 @@ class GaussianDiffusion: shape, noise=None, clip_denoised=True, + causal=False, + causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, - progress=False, + progress=True, eta=0.0, ): """ @@ -772,8 +781,15 @@ class GaussianDiffusion: indices = tqdm(indices) + orig_img = img for i in indices: t = th.tensor([i] * shape[0], device=device) + mask = torch.zeros_like(img) + if causal: + t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1) + mask = t == self.num_timesteps + t[mask] = self.num_timesteps-1 + mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): out = self.ddim_sample( model, @@ -787,6 +803,8 @@ class GaussianDiffusion: ) yield out img = out["sample"] + if torch.any(mask): + img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 9d52d177..4a886983 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -37,6 +37,8 @@ class MusicDiffusionFid(evaluator.Evaluator): self.data = self.load_data(self.real_path) self.clip = opt_get(opt_eval, ['clip_audio'], True) # Recommend setting true for more efficient eval passes. self.ddim = opt_get(opt_eval, ['use_ddim'], False) + self.causal = opt_get(opt_eval, ['causal'], True) + self.causal_slope = opt_get(opt_eval, ['causal_slope'], 1) if distributed.is_initialized() and distributed.get_world_size() > 1: self.skip = distributed.get_world_size() # One batch element per GPU. else: @@ -84,15 +86,6 @@ class MusicDiffusionFid(evaluator.Evaluator): conditioning_free=True, conditioning_free_k=1) self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent. self.local_modules['spec_decoder'] = self.spec_decoder - elif 'cheater_gen_fake_ar' == mode: - self.diffusion_fn = self.perform_fake_ar_reconstruction_from_cheater_gen - self.local_modules['cheater_encoder'] = get_cheater_encoder() - self.local_modules['cheater_decoder'] = get_cheater_decoder() - self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon', - model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), - conditioning_free=True, conditioning_free_k=1) - self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent. - self.local_modules['spec_decoder'] = self.spec_decoder elif 'from_ar_prior' == mode: self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior self.local_modules['cheater_encoder'] = get_cheater_encoder() @@ -230,7 +223,9 @@ class MusicDiffusionFid(evaluator.Evaluator): cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm) # 1. Generate the cheater latent using the input as a reference. - gen_cheater = self.diffuser.ddim_sample_loop(self.model, cheater.shape, progress=True, model_kwargs={'conditioning_input': cheater}) + gen_cheater = self.diffuser.ddim_sample_loop(self.model, cheater.shape, progress=True, + model_kwargs={'conditioning_input': cheater}, + causal=self.causal, causal_slope=self.causal_slope) # 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory. chunks = torch.split(gen_cheater, 64, dim=-1) @@ -400,9 +395,9 @@ class MusicDiffusionFid(evaluator.Evaluator): real_projections = [] for i in tqdm(list(range(0, len(self.data), self.skip))): path = self.data[(i + self.env['rank']) % len(self.data)] - #audio = load_audio(path, 22050).to(self.dev) - audio = load_audio('C:\\Users\\James\\Music\\another_longer_sample.wav', 22050).to(self.dev) # <- hack, remove it! - audio = audio[:, :1764000] + audio = load_audio(path, 22050).to(self.dev) + #audio = load_audio('C:\\Users\\James\\Music\\another_longer_sample.wav', 22050).to(self.dev) # <- hack, remove it! + #audio = audio[:, :1764000] if self.clip: audio = audio[:, :100000] sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio) @@ -436,13 +431,14 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen_r8.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5\\models\\203000_generator_ema.pth' + load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal\\models\\1000_generator.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. 'diffusion_steps': 64, 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False, - 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen_fake_ar', + 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen', + 'causal': True, 'causal_slope': 4, #'partial_low': 128, 'partial_high': 192 } env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 232, 'device': 'cuda', 'opt': {}}