From a9387179db0adda76babe2c58962fbbc336f2e3a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 3 Jun 2022 15:19:23 -0600
Subject: [PATCH] add channel loss balancing

---
 codes/models/diffusion/gaussian_diffusion.py  |  7 +++++--
 .../injectors/gaussian_diffusion_injector.py  | 16 +++++++++++++++-
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 4a63a899..9b8e38b8 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -790,7 +790,7 @@ class GaussianDiffusion:
         output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}
 
-    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
         """
         Compute training losses for a single timestep.
 
@@ -867,7 +867,10 @@ class GaussianDiffusion:
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
-            terms["mse"] = mean_flat((target - model_output) ** 2)
+            s_err = (target - model_output) ** 2
+            if channel_balancing_fn is not None:
+                s_err = channel_balancing_fn(s_err)
+            terms["mse"] = mean_flat(s_err)
             terms["x_start_predicted"] = x_start_pred
             if "vb" in terms:
                 terms["loss"] = terms["mse"] + terms["vb"]
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 5e5b8808..b6650abd 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -1,3 +1,4 @@
+import functools
 import random
 import time
 
@@ -11,6 +12,17 @@
 from trainer.inject import Injector
 from utils.util import opt_get
 
+
+def masked_channel_balancer(inp, proportion=1):
+    with torch.no_grad():
+        only_channels = inp.mean(dim=(0,2))  # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
+        dist = only_channels / only_channels.sum()
+        dist_mult = only_channels.shape[0] * proportion
+        dist = (dist * dist_mult).clamp(0, 1)
+        mask = torch.bernoulli(dist)
+    return inp * mask.view(1,inp.shape[1],1)
+
+
 # Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper.
 # Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
 class GaussianDiffusionInjector(Injector):
@@ -28,6 +40,8 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
+            if 'channel_balancer_proportion' in opt.keys() else None
         self.recent_loss = 0
 
     def extra_metrics(self):
@@ -48,7 +62,7 @@ class GaussianDiffusionInjector(Injector):
         self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
         t, weights = sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
         if isinstance(sampler, LossAwareSampler):
             sampler.update_with_local_losses(t, diffusion_outputs['losses'])
         if len(self.extra_model_output_keys) > 0: