From 34005367fda863a52a9decd85fdd52a108f053e8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 9 Jun 2022 21:41:20 -0600 Subject: [PATCH] setup for partial channel diffusion --- codes/models/diffusion/gaussian_diffusion.py | 88 ++----------------- codes/models/diffusion/respace.py | 5 -- codes/trainer/eval/music_diffusion_fid.py | 36 +++++++- .../injectors/gaussian_diffusion_injector.py | 30 ++++--- 4 files changed, 60 insertions(+), 99 deletions(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 9b8e38b8..274c0520 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -777,13 +777,13 @@ class GaussianDiffusion: kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) - kl = mean_flat(kl) / np.log(2.0) + kl = kl / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape - decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + decoder_nll = decoder_nll / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) @@ -813,14 +813,14 @@ class GaussianDiffusion: if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: # TODO: support multiple model outputs for this mode. - terms["loss"] = self._vb_terms_bpd( + terms["loss"] = mean_flat(self._vb_terms_bpd( model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, - )["output"] + )["output"]) if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: @@ -873,79 +873,9 @@ class GaussianDiffusion: terms["mse"] = mean_flat(s_err) terms["x_start_predicted"] = x_start_pred if "vb" in terms: - terms["loss"] = terms["mse"] + terms["vb"] - else: - terms["loss"] = terms["mse"] - else: - raise NotImplementedError(self.loss_type) - - return terms - - def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None): - """ - Compute training losses for a single timestep. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param t: a batch of timestep indices. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param noise: if specified, the specific Gaussian noise to try to remove. - :return: a dict with the key "loss" containing a tensor of shape [N]. - Some mean or variance settings may also have other keys. - """ - if model_kwargs is None: - model_kwargs = {} - if noise is None: - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) - terms = {} - if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: - assert False # not currently supported for this type of diffusion. - elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) - terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) - model_output = terms[gd_out_key] - if self.model_var_type in [ - ModelVarType.LEARNED, - ModelVarType.LEARNED_RANGE, - ]: - B, C = x_t.shape[:2] - assert model_output.shape == (B, C, 2, *x_t.shape[2:]) - model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1] - # Learn the variance using the variational bound, but don't let - # it affect our mean prediction. - frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) - terms["vb"] = self._vb_terms_bpd( - model=lambda *args, r=frozen_out: r, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - )["output"] - if self.loss_type == LossType.RESCALED_MSE: - # Divide by 1000 for equivalence with initial implementation. - # Without a factor of 1/1000, the VB term hurts the MSE term. - terms["vb"] *= self.num_timesteps / 1000.0 - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - )[0] - x_start_pred = torch.zeros(x_start) # Not supported. - elif self.model_mean_type == ModelMeanType.START_X: - target = x_start - x_start_pred = model_output - elif self.model_mean_type == ModelMeanType.EPSILON: - target = noise - x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) - else: - raise NotImplementedError(self.model_mean_type) - assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) - terms["x_start_predicted"] = x_start_pred - if "vb" in terms: - terms["loss"] = terms["mse"] + terms["vb"] + if channel_balancing_fn is not None: + terms["vb"] = channel_balancing_fn(terms["vb"]) + terms["loss"] = terms["mse"] + mean_flat(terms["vb"]) else: terms["loss"] = terms["mse"] else: @@ -1001,14 +931,14 @@ class GaussianDiffusion: x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): - out = self._vb_terms_bpd( + out = mean_flat(self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, - ) + )) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) diff --git a/codes/models/diffusion/respace.py b/codes/models/diffusion/respace.py index 4fad2f8b..78403ebb 100644 --- a/codes/models/diffusion/respace.py +++ b/codes/models/diffusion/respace.py @@ -95,11 +95,6 @@ class SpacedDiffusion(GaussianDiffusion): ): # pylint: disable=signature-differs return super().training_losses(self._wrap_model(model), *args, **kwargs) - def autoregressive_training_losses( - self, model, *args, **kwargs - ): # pylint: disable=signature-differs - return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) - def condition_mean(self, cond_fn, *args, **kwargs): return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 35abc7b6..8d7a1100 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -1,3 +1,4 @@ +import functools import os import os.path as osp from glob import glob @@ -62,6 +63,10 @@ class MusicDiffusionFid(evaluator.Evaluator): self.local_modules['codegen'] = get_music_codegen() elif 'from_codes_quant' == mode: self.diffusion_fn = self.perform_diffusion_from_codes_quant + elif 'partial_from_codes_quant' == mode: + self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant, + partial_low=opt_eval['partial_low'], + partial_high=opt_eval['partial_high']) elif 'from_codes_quant_gradual_decode' == mode: self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, @@ -135,6 +140,32 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate + def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256): + if sample_rate != sample_rate: + real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) + else: + real_resampled = audio + audio = audio.unsqueeze(0) + + mel = self.spec_fn({'in': audio})['out'] + mel_norm = normalize_mel(mel) + mask = torch.ones_like(mel_norm) + mask[:, partial_low:partial_high] = 0 # This is the channel region that the model will predict. + gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model, + guidance_input=mel_norm, mask=mask, + model_kwargs={'truth_mel': mel, + 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]), + 'disable_diversity': True}) + + gen_mel_denorm = denormalize_mel(gen_mel) + output_shape = (1,16,audio.shape[-1]//16) + self.spec_decoder = self.spec_decoder.to(audio.device) + gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'aligned_conditioning': gen_mel_denorm}) + gen_wav = pixel_shuffle_1d(gen_wav, 16) + + return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate + def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050): if sample_rate != sample_rate: real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) @@ -243,7 +274,8 @@ if __name__ == '__main__': 'path': 'E:\\music_eval', 'diffusion_steps': 100, 'conditioning_free': False, 'conditioning_free_k': 1, - 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'} - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 502, 'device': 'cuda', 'opt': {}} + 'diffusion_schedule': 'linear', 'diffusion_type': 'partial_from_codes_quant', + 'partial_low': 128, 'partial_high': 192} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 504, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval()) diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index b6650abd..3d71e2b9 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -1,18 +1,15 @@ import functools -import random -import time import torch from torch.cuda.amp import autocast -from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule +from models.diffusion.gaussian_diffusion import get_named_beta_schedule from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler from models.diffusion.respace import space_timesteps, SpacedDiffusion from trainer.inject import Injector from utils.util import opt_get - def masked_channel_balancer(inp, proportion=1): with torch.no_grad(): only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities. @@ -23,6 +20,12 @@ def masked_channel_balancer(inp, proportion=1): return inp * mask.view(1,inp.shape[1],1) +def channel_restriction(inp, low, high): + m = torch.zeros_like(inp) + m[:,low:high] = 1 + return inp * m + + # Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper. # Largely uses OpenAI's own code to do so (all code from models.diffusion.*) class GaussianDiffusionInjector(Injector): @@ -40,14 +43,17 @@ class GaussianDiffusionInjector(Injector): self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], []) self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0) self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env) - self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \ - if 'channel_balancer_proportion' in opt.keys() else None - self.recent_loss = 0 - def extra_metrics(self): - return { - 'exp_diffusion_loss': torch.exp(self.recent_loss.mean()), - } + k = 0 + if 'channel_balancer_proportion' in opt.keys(): + self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) + k += 1 + if 'channel_restriction_low' in opt.keys(): + self.channel_balancing_fn = functools.partial(channel_restriction, low=opt['channel_restriction_low'], high=opt['channel_restriction_high']) + k += 1 + if not hasattr(self, 'channel_balancing_fn'): + self.channel_balancing_fn = None + assert k <= 1, 'Only one channel filtering function can be applied.' def forward(self, state): gen = self.env['generators'][self.opt['generator']] @@ -74,8 +80,6 @@ class GaussianDiffusionInjector(Injector): self.output_variational_bounds_key: diffusion_outputs['vb'], self.output_x_start_key: diffusion_outputs['x_start_predicted']}) - self.recent_loss = diffusion_outputs['mse'] - return out