add channel loss balancing
commit a9387179db
parent 40ba802104
@@ -790,7 +790,7 @@ class GaussianDiffusion:
         output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}
 
-    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
         """
         Compute training losses for a single timestep.
 
@@ -867,7 +867,10 @@ class GaussianDiffusion:
         else:
             raise NotImplementedError(self.model_mean_type)
         assert model_output.shape == target.shape == x_start.shape
-        terms["mse"] = mean_flat((target - model_output) ** 2)
+        s_err = (target - model_output) ** 2
+        if channel_balancing_fn is not None:
+            s_err = channel_balancing_fn(s_err)
+        terms["mse"] = mean_flat(s_err)
         terms["x_start_predicted"] = x_start_pred
         if "vb" in terms:
             terms["loss"] = terms["mse"] + terms["vb"]
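For reference, the new loss path can be read as the short sketch below. This is not part of the commit: mean_flat is assumed to average over all non-batch dimensions (its semantics in OpenAI's reference diffusion code), and channel_balancing_fn is any callable mapping the squared-error tensor to a reweighted tensor of the same shape, such as the masked_channel_balancer added further down.

    import torch

    def mean_flat(tensor):
        # Assumed semantics (as in OpenAI's reference code): mean over all
        # non-batch dimensions, yielding one scalar loss per batch element.
        return tensor.mean(dim=list(range(1, len(tensor.shape))))

    def mse_term(target, model_output, channel_balancing_fn=None):
        s_err = (target - model_output) ** 2
        if channel_balancing_fn is not None:
            s_err = channel_balancing_fn(s_err)  # e.g. zero out a sampled subset of channels
        return mean_flat(s_err)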
@@ -1,3 +1,4 @@
+import functools
 import random
 import time
 
@@ -11,6 +12,17 @@ from trainer.inject import Injector
 from utils.util import opt_get
 
 
+
+def masked_channel_balancer(inp, proportion=1):
+    with torch.no_grad():
+        only_channels = inp.mean(dim=(0, 2))  # Currently only works for audio tensors. Could be retrofitted for 2D (or 3D!) modalities.
+        dist = only_channels / only_channels.sum()
+        dist_mult = only_channels.shape[0] * proportion
+        dist = (dist * dist_mult).clamp(0, 1)
+        mask = torch.bernoulli(dist)
+    return inp * mask.view(1, inp.shape[1], 1)
+
+
 # Injects a gaussian diffusion loss as described by OpenAI's "Improved Denoising Diffusion Probabilistic Models" paper.
 # Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
 class GaussianDiffusionInjector(Injector):
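A minimal sketch of what the balancer does, assuming a (batch, channels, time) audio error tensor and the masked_channel_balancer defined in the hunk above being in scope: each channel is kept with probability proportional to its mean error, scaled so that roughly a `proportion` fraction of channels survive per call, with high-error channels favored.

    import torch

    s_err = torch.rand(4, 100, 256)  # hypothetical (batch, channels, time) squared-error tensor
    balanced = masked_channel_balancer(s_err, proportion=0.5)

    # Channels zeroed by the Bernoulli mask contribute nothing to the loss;
    # on average about `proportion` of them survive (modulo the clamp).
    kept = (balanced.abs().sum(dim=(0, 2)) > 0).float().mean().item()
    print(f"fraction of channels kept: {kept:.2f}")  # ~0.50 in expectation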
@@ -28,6 +40,8 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
+            if 'channel_balancer_proportion' in opt.keys() else None
         self.recent_loss = 0
 
     def extra_metrics(self):
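Since the feature is gated on a single opt key, enabling it from an injector configuration is just a matter of setting that key. A hypothetical sketch, not part of this commit; only 'channel_balancer_proportion' is a key the code actually reads:

    # Hypothetical injector opt dict for illustration.
    opt = {
        'channel_balancer_proportion': 0.25,  # keep roughly 25% of channels per step
        # ... remaining injector options as elsewhere in the trainer config ...
    }
    # With the key absent, self.channel_balancing_fn is None and
    # training_losses falls back to the usual unweighted MSE.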
@@ -48,7 +62,7 @@ class GaussianDiffusionInjector(Injector):
         self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
         t, weights = sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
         if isinstance(sampler, LossAwareSampler):
             sampler.update_with_local_losses(t, diffusion_outputs['losses'])
         if len(self.extra_model_output_keys) > 0: