forked from mrq/DL-Art-School
add channel loss balancing
This commit is contained in:
parent
40ba802104
commit
a9387179db
|
@ -790,7 +790,7 @@ class GaussianDiffusion:
|
|||
output = th.where((t == 0), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
||||
|
@ -867,7 +867,10 @@ class GaussianDiffusion:
|
|||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2)
|
||||
s_err = (target - model_output) ** 2
|
||||
if channel_balancing_fn is not None:
|
||||
s_err = channel_balancing_fn(s_err)
|
||||
terms["mse"] = mean_flat(s_err)
|
||||
terms["x_start_predicted"] = x_start_pred
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import functools
|
||||
import random
|
||||
import time
|
||||
|
||||
|
@ -11,6 +12,17 @@ from trainer.inject import Injector
|
|||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
||||
def masked_channel_balancer(inp, proportion=1):
    """
    Randomly zero out whole channels of `inp`, keeping channels with a larger mean value more often.

    Each channel (dim 1) is kept with probability
    clamp(channel_mean / sum(channel_means) * num_channels * proportion, 0, 1),
    sampled independently per channel with torch.bernoulli. Dropped channels are
    multiplied by 0, kept channels are passed through unscaled.

    :param inp: floating point tensor shaped (batch, channels, ...) with at least 2 dims. Presumably a
                non-negative per-element loss (e.g. squared error) - verify against the caller.
    :param proportion: multiplier applied to every keep-probability before clamping; 0 drops every channel.
    :return: tensor with the same shape as `inp`. The mask itself is built without gradient tracking, but
             the final multiply is differentiable with respect to `inp`.
    """
    with torch.no_grad():
        # Reduce over every dim except the channel dim, so 1D (audio), 2D and 3D modalities all work.
        reduce_dims = tuple(d for d in range(inp.dim()) if d != 1)
        only_channels = inp.mean(dim=reduce_dims)
        total = only_channels.sum()
        # An all-zero input would otherwise produce 0/0 = NaN probabilities for torch.bernoulli.
        # Swapping a zero total for 1 yields all-zero probabilities instead and leaves every other case untouched.
        total = torch.where(total != 0, total, torch.ones_like(total))
        dist = only_channels / total
        dist_mult = only_channels.shape[0] * proportion
        dist = (dist * dist_mult).clamp(0, 1)
        mask = torch.bernoulli(dist)
    # Broadcast the per-channel mask across the batch dim and all trailing dims.
    mask_shape = [1] * inp.dim()
    mask_shape[1] = inp.shape[1]
    return inp * mask.view(mask_shape)
|
||||
|
||||
|
||||
# Injects a Gaussian diffusion loss as described by OpenAI's "Improved Denoising Diffusion Probabilistic Models" paper.
|
||||
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
|
||||
class GaussianDiffusionInjector(Injector):
|
||||
|
@ -28,6 +40,8 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
||||
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
|
||||
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
|
||||
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
|
||||
if 'channel_balancer_proportion' in opt.keys() else None
|
||||
self.recent_loss = 0
|
||||
|
||||
def extra_metrics(self):
|
||||
|
@ -48,7 +62,7 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
|
||||
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
|
||||
t, weights = sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
|
||||
if isinstance(sampler, LossAwareSampler):
|
||||
sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||
if len(self.extra_model_output_keys) > 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user