add channel loss balancing

This commit is contained in:
James Betker 2022-06-03 15:19:23 -06:00
parent 40ba802104
commit a9387179db
2 changed files with 20 additions and 3 deletions

View File

@ -790,7 +790,7 @@ class GaussianDiffusion:
output = th.where((t == 0), decoder_nll, kl) output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]} return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
""" """
Compute training losses for a single timestep. Compute training losses for a single timestep.
@ -867,7 +867,10 @@ class GaussianDiffusion:
else: else:
raise NotImplementedError(self.model_mean_type) raise NotImplementedError(self.model_mean_type)
assert model_output.shape == target.shape == x_start.shape assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2) s_err = (target - model_output) ** 2
if channel_balancing_fn is not None:
s_err = channel_balancing_fn(s_err)
terms["mse"] = mean_flat(s_err)
terms["x_start_predicted"] = x_start_pred terms["x_start_predicted"] = x_start_pred
if "vb" in terms: if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"] terms["loss"] = terms["mse"] + terms["vb"]

View File

@ -1,3 +1,4 @@
import functools
import random import random
import time import time
@ -11,6 +12,17 @@ from trainer.inject import Injector
from utils.util import opt_get from utils.util import opt_get
def masked_channel_balancer(inp, proportion=1):
with torch.no_grad():
only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
dist = only_channels / only_channels.sum()
dist_mult = only_channels.shape[0] * proportion
dist = (dist * dist_mult).clamp(0, 1)
mask = torch.bernoulli(dist)
return inp * mask.view(1,inp.shape[1],1)
# Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper. # Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper.
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*) # Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
class GaussianDiffusionInjector(Injector): class GaussianDiffusionInjector(Injector):
@ -28,6 +40,8 @@ class GaussianDiffusionInjector(Injector):
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], []) self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0) self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env) self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
if 'channel_balancer_proportion' in opt.keys() else None
self.recent_loss = 0 self.recent_loss = 0
def extra_metrics(self): def extra_metrics(self):
@ -48,7 +62,7 @@ class GaussianDiffusionInjector(Injector):
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically. self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()} model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
t, weights = sampler.sample(hq.shape[0], hq.device) t, weights = sampler.sample(hq.shape[0], hq.device)
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs) diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
if isinstance(sampler, LossAwareSampler): if isinstance(sampler, LossAwareSampler):
sampler.update_with_local_losses(t, diffusion_outputs['losses']) sampler.update_with_local_losses(t, diffusion_outputs['losses'])
if len(self.extra_model_output_keys) > 0: if len(self.extra_model_output_keys) > 0: