setup for partial channel diffusion

This commit is contained in:
James Betker 2022-06-09 21:41:20 -06:00
parent 47b34f5cb9
commit 34005367fd
4 changed files with 60 additions and 99 deletions

View File

@ -777,13 +777,13 @@ class GaussianDiffusion:
kl = normal_kl( kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
) )
kl = mean_flat(kl) / np.log(2.0) kl = kl / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood( decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
) )
assert decoder_nll.shape == x_start.shape assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0) decoder_nll = decoder_nll / np.log(2.0)
# At the first timestep return the decoder NLL, # At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
@ -813,14 +813,14 @@ class GaussianDiffusion:
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
# TODO: support multiple model outputs for this mode. # TODO: support multiple model outputs for this mode.
terms["loss"] = self._vb_terms_bpd( terms["loss"] = mean_flat(self._vb_terms_bpd(
model=model, model=model,
x_start=x_start, x_start=x_start,
x_t=x_t, x_t=x_t,
t=t, t=t,
clip_denoised=False, clip_denoised=False,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
)["output"] )["output"])
if self.loss_type == LossType.RESCALED_KL: if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps terms["loss"] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
@ -873,79 +873,9 @@ class GaussianDiffusion:
terms["mse"] = mean_flat(s_err) terms["mse"] = mean_flat(s_err)
terms["x_start_predicted"] = x_start_pred terms["x_start_predicted"] = x_start_pred
if "vb" in terms: if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"] if channel_balancing_fn is not None:
else: terms["vb"] = channel_balancing_fn(terms["vb"])
terms["loss"] = terms["mse"] terms["loss"] = terms["mse"] + mean_flat(terms["vb"])
else:
raise NotImplementedError(self.loss_type)
return terms
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
assert False # not currently supported for this type of diffusion.
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
model_output = terms[gd_out_key]
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C, 2, *x_t.shape[2:])
model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0]
x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X:
target = x_start
x_start_pred = model_output
elif self.model_mean_type == ModelMeanType.EPSILON:
target = noise
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
else:
raise NotImplementedError(self.model_mean_type)
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
terms["x_start_predicted"] = x_start_pred
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else: else:
terms["loss"] = terms["mse"] terms["loss"] = terms["mse"]
else: else:
@ -1001,14 +931,14 @@ class GaussianDiffusion:
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep # Calculate VLB term at the current timestep
with th.no_grad(): with th.no_grad():
out = self._vb_terms_bpd( out = mean_flat(self._vb_terms_bpd(
model, model,
x_start=x_start, x_start=x_start,
x_t=x_t, x_t=x_t,
t=t_batch, t=t_batch,
clip_denoised=clip_denoised, clip_denoised=clip_denoised,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
) ))
vb.append(out["output"]) vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])

View File

@ -95,11 +95,6 @@ class SpacedDiffusion(GaussianDiffusion):
): # pylint: disable=signature-differs ): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs) return super().training_losses(self._wrap_model(model), *args, **kwargs)
def autoregressive_training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs): def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

View File

@ -1,3 +1,4 @@
import functools
import os import os
import os.path as osp import os.path as osp
from glob import glob from glob import glob
@ -62,6 +63,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.local_modules['codegen'] = get_music_codegen() self.local_modules['codegen'] = get_music_codegen()
elif 'from_codes_quant' == mode: elif 'from_codes_quant' == mode:
self.diffusion_fn = self.perform_diffusion_from_codes_quant self.diffusion_fn = self.perform_diffusion_from_codes_quant
elif 'partial_from_codes_quant' == mode:
self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant,
partial_low=opt_eval['partial_low'],
partial_high=opt_eval['partial_high'])
elif 'from_codes_quant_gradual_decode' == mode: elif 'from_codes_quant_gradual_decode' == mode:
self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
@ -135,6 +140,32 @@ class MusicDiffusionFid(evaluator.Evaluator):
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
mask = torch.ones_like(mel_norm)
mask[:, partial_low:partial_high] = 0 # This is the channel region that the model will predict.
gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
guidance_input=mel_norm, mask=mask,
model_kwargs={'truth_mel': mel,
'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
'disable_diversity': True})
gen_mel_denorm = denormalize_mel(gen_mel)
output_shape = (1,16,audio.shape[-1]//16)
self.spec_decoder = self.spec_decoder.to(audio.device)
gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
model_kwargs={'aligned_conditioning': gen_mel_denorm})
gen_wav = pixel_shuffle_1d(gen_wav, 16)
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050): def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
if sample_rate != sample_rate: if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
@ -243,7 +274,8 @@ if __name__ == '__main__':
'path': 'E:\\music_eval', 'path': 'E:\\music_eval',
'diffusion_steps': 100, 'diffusion_steps': 100,
'conditioning_free': False, 'conditioning_free_k': 1, 'conditioning_free': False, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'} 'diffusion_schedule': 'linear', 'diffusion_type': 'partial_from_codes_quant',
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 502, 'device': 'cuda', 'opt': {}} 'partial_low': 128, 'partial_high': 192}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 504, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env) eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval()) print(eval.perform_eval())

View File

@ -1,18 +1,15 @@
import functools import functools
import random
import time
import torch import torch
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
from models.diffusion.respace import space_timesteps, SpacedDiffusion from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.inject import Injector from trainer.inject import Injector
from utils.util import opt_get from utils.util import opt_get
def masked_channel_balancer(inp, proportion=1): def masked_channel_balancer(inp, proportion=1):
with torch.no_grad(): with torch.no_grad():
only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities. only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
@ -23,6 +20,12 @@ def masked_channel_balancer(inp, proportion=1):
return inp * mask.view(1,inp.shape[1],1) return inp * mask.view(1,inp.shape[1],1)
def channel_restriction(inp, low, high):
m = torch.zeros_like(inp)
m[:,low:high] = 1
return inp * m
# Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper. # Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper.
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*) # Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
class GaussianDiffusionInjector(Injector): class GaussianDiffusionInjector(Injector):
@ -40,14 +43,17 @@ class GaussianDiffusionInjector(Injector):
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], []) self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0) self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env) self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
if 'channel_balancer_proportion' in opt.keys() else None
self.recent_loss = 0
def extra_metrics(self): k = 0
return { if 'channel_balancer_proportion' in opt.keys():
'exp_diffusion_loss': torch.exp(self.recent_loss.mean()), self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion'])
} k += 1
if 'channel_restriction_low' in opt.keys():
self.channel_balancing_fn = functools.partial(channel_restriction, low=opt['channel_restriction_low'], high=opt['channel_restriction_high'])
k += 1
if not hasattr(self, 'channel_balancing_fn'):
self.channel_balancing_fn = None
assert k <= 1, 'Only one channel filtering function can be applied.'
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
@ -74,8 +80,6 @@ class GaussianDiffusionInjector(Injector):
self.output_variational_bounds_key: diffusion_outputs['vb'], self.output_variational_bounds_key: diffusion_outputs['vb'],
self.output_x_start_key: diffusion_outputs['x_start_predicted']}) self.output_x_start_key: diffusion_outputs['x_start_predicted']})
self.recent_loss = diffusion_outputs['mse']
return out return out