setup for partial channel diffusion
This commit is contained in:
parent
47b34f5cb9
commit
34005367fd
|
@ -777,13 +777,13 @@ class GaussianDiffusion:
|
|||
kl = normal_kl(
|
||||
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
||||
)
|
||||
kl = mean_flat(kl) / np.log(2.0)
|
||||
kl = kl / np.log(2.0)
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(
|
||||
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
||||
)
|
||||
assert decoder_nll.shape == x_start.shape
|
||||
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
||||
decoder_nll = decoder_nll / np.log(2.0)
|
||||
|
||||
# At the first timestep return the decoder NLL,
|
||||
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
|
@ -813,14 +813,14 @@ class GaussianDiffusion:
|
|||
|
||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||
# TODO: support multiple model outputs for this mode.
|
||||
terms["loss"] = self._vb_terms_bpd(
|
||||
terms["loss"] = mean_flat(self._vb_terms_bpd(
|
||||
model=model,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
model_kwargs=model_kwargs,
|
||||
)["output"]
|
||||
)["output"])
|
||||
if self.loss_type == LossType.RESCALED_KL:
|
||||
terms["loss"] *= self.num_timesteps
|
||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||
|
@ -873,79 +873,9 @@ class GaussianDiffusion:
|
|||
terms["mse"] = mean_flat(s_err)
|
||||
terms["x_start_predicted"] = x_start_pred
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
else:
|
||||
terms["loss"] = terms["mse"]
|
||||
else:
|
||||
raise NotImplementedError(self.loss_type)
|
||||
|
||||
return terms
|
||||
|
||||
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
||||
:param model: the model to evaluate loss on.
|
||||
:param x_start: the [N x C x ...] tensor of inputs.
|
||||
:param t: a batch of timestep indices.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:param noise: if specified, the specific Gaussian noise to try to remove.
|
||||
:return: a dict with the key "loss" containing a tensor of shape [N].
|
||||
Some mean or variance settings may also have other keys.
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
x_t = self.q_sample(x_start, t, noise=noise)
|
||||
terms = {}
|
||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||
assert False # not currently supported for this type of diffusion.
|
||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
|
||||
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
|
||||
model_output = terms[gd_out_key]
|
||||
if self.model_var_type in [
|
||||
ModelVarType.LEARNED,
|
||||
ModelVarType.LEARNED_RANGE,
|
||||
]:
|
||||
B, C = x_t.shape[:2]
|
||||
assert model_output.shape == (B, C, 2, *x_t.shape[2:])
|
||||
model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
|
||||
# Learn the variance using the variational bound, but don't let
|
||||
# it affect our mean prediction.
|
||||
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
||||
terms["vb"] = self._vb_terms_bpd(
|
||||
model=lambda *args, r=frozen_out: r,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
)["output"]
|
||||
if self.loss_type == LossType.RESCALED_MSE:
|
||||
# Divide by 1000 for equivalence with initial implementation.
|
||||
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
||||
terms["vb"] *= self.num_timesteps / 1000.0
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
target = self.q_posterior_mean_variance(
|
||||
x_start=x_start, x_t=x_t, t=t
|
||||
)[0]
|
||||
x_start_pred = torch.zeros(x_start) # Not supported.
|
||||
elif self.model_mean_type == ModelMeanType.START_X:
|
||||
target = x_start
|
||||
x_start_pred = model_output
|
||||
elif self.model_mean_type == ModelMeanType.EPSILON:
|
||||
target = noise
|
||||
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
||||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2)
|
||||
terms["x_start_predicted"] = x_start_pred
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
if channel_balancing_fn is not None:
|
||||
terms["vb"] = channel_balancing_fn(terms["vb"])
|
||||
terms["loss"] = terms["mse"] + mean_flat(terms["vb"])
|
||||
else:
|
||||
terms["loss"] = terms["mse"]
|
||||
else:
|
||||
|
@ -1001,14 +931,14 @@ class GaussianDiffusion:
|
|||
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
||||
# Calculate VLB term at the current timestep
|
||||
with th.no_grad():
|
||||
out = self._vb_terms_bpd(
|
||||
out = mean_flat(self._vb_terms_bpd(
|
||||
model,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t_batch,
|
||||
clip_denoised=clip_denoised,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
))
|
||||
vb.append(out["output"])
|
||||
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
||||
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
||||
|
|
|
@ -95,11 +95,6 @@ class SpacedDiffusion(GaussianDiffusion):
|
|||
): # pylint: disable=signature-differs
|
||||
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def autoregressive_training_losses(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
|
||||
|
||||
def condition_mean(self, cond_fn, *args, **kwargs):
    # Pass the conditioning function through _wrap_model first, then hand the wrapped callable
    # (plus every other argument, untouched) to the parent implementation.
    wrapped_cond_fn = self._wrap_model(cond_fn)
    return super().condition_mean(wrapped_cond_fn, *args, **kwargs)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import functools
|
||||
import os
|
||||
import os.path as osp
|
||||
from glob import glob
|
||||
|
@ -62,6 +63,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
self.local_modules['codegen'] = get_music_codegen()
|
||||
elif 'from_codes_quant' == mode:
|
||||
self.diffusion_fn = self.perform_diffusion_from_codes_quant
|
||||
elif 'partial_from_codes_quant' == mode:
|
||||
self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant,
|
||||
partial_low=opt_eval['partial_low'],
|
||||
partial_high=opt_eval['partial_high'])
|
||||
elif 'from_codes_quant_gradual_decode' == mode:
|
||||
self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
|
||||
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
|
||||
|
@ -135,6 +140,32 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
|
||||
|
||||
def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
    """
    Re-generate only the MEL channels in [partial_low, partial_high) with the diffusion model while the
    remaining channels are supplied as guidance, then decode the resulting MEL into a waveform.

    :param audio: input waveform tensor. A leading dimension is added before it is given to `self.spec_fn`.
                  NOTE(review): the resample call assumes this audio is sampled at 22050Hz - confirm with caller.
    :param sample_rate: sample rate of the reference audio that is returned. When it differs from 22050 the
                        reference is resampled to this rate.
    :param partial_low: first channel (inclusive) of the region the model will predict.
    :param partial_high: last channel (exclusive) of the region the model will predict.
    :return: tuple of (gen_wav, real_resampled, gen_mel, mel_norm, sample_rate).
    """
    native_sample_rate = 22050
    # This guard used to read `sample_rate != sample_rate`, which is always False, so the resample branch
    # was unreachable. Compare against the rate the resample call already uses as its source rate.
    if sample_rate != native_sample_rate:
        real_resampled = torchaudio.functional.resample(audio, native_sample_rate, sample_rate).unsqueeze(0)
    else:
        # NOTE(review): this branch returns the reference without the extra leading dim that the resample
        # branch adds; left as-is because existing callers receive this shape today.
        real_resampled = audio
    audio = audio.unsqueeze(0)

    mel = self.spec_fn({'in': audio})['out']
    mel_norm = normalize_mel(mel)
    mask = torch.ones_like(mel_norm)
    mask[:, partial_low:partial_high] = 0  # This is the channel region that the model will predict.
    gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
                                                        guidance_input=mel_norm, mask=mask,
                                                        model_kwargs={'truth_mel': mel,
                                                                      'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
                                                                      'disable_diversity': True})

    gen_mel_denorm = denormalize_mel(gen_mel)
    # The spectrogram decoder is sampled with 16 channels at 1/16th of the audio length; pixel_shuffle_1d
    # is then applied with the same factor of 16 to produce the final waveform tensor.
    output_shape = (1,16,audio.shape[-1]//16)
    self.spec_decoder = self.spec_decoder.to(audio.device)
    gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                          model_kwargs={'aligned_conditioning': gen_mel_denorm})
    gen_wav = pixel_shuffle_1d(gen_wav, 16)

    return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
|
||||
|
||||
def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
|
||||
if sample_rate != sample_rate:
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||
|
@ -243,7 +274,8 @@ if __name__ == '__main__':
|
|||
'path': 'E:\\music_eval',
|
||||
'diffusion_steps': 100,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 502, 'device': 'cuda', 'opt': {}}
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'partial_from_codes_quant',
|
||||
'partial_low': 128, 'partial_high': 192}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 504, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
|
@ -1,18 +1,15 @@
|
|||
import functools
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
|
||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||
from trainer.inject import Injector
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
||||
def masked_channel_balancer(inp, proportion=1):
|
||||
with torch.no_grad():
|
||||
only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
|
||||
|
@ -23,6 +20,12 @@ def masked_channel_balancer(inp, proportion=1):
|
|||
return inp * mask.view(1,inp.shape[1],1)
|
||||
|
||||
|
||||
def channel_restriction(inp, low, high):
    """
    Zeroes every entry of `inp` whose index along dim 1 lies outside the slice low:high.

    :param inp: tensor with at least two dimensions; dim 1 is the one being restricted.
    :param low: start of the retained slice along dim 1.
    :param high: end (exclusive) of the retained slice along dim 1.
    :return: `inp` multiplied by a 0/1 mask of the same shape, so only the low:high slice is kept.
    """
    keep = torch.zeros_like(inp)
    keep[:, slice(low, high)] = 1
    restricted = inp * keep
    return restricted
|
||||
|
||||
|
||||
# Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper.
|
||||
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
|
||||
class GaussianDiffusionInjector(Injector):
|
||||
|
@ -40,14 +43,17 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
||||
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
|
||||
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
|
||||
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
|
||||
if 'channel_balancer_proportion' in opt.keys() else None
|
||||
self.recent_loss = 0
|
||||
|
||||
def extra_metrics(self):
|
||||
return {
|
||||
'exp_diffusion_loss': torch.exp(self.recent_loss.mean()),
|
||||
}
|
||||
k = 0
|
||||
if 'channel_balancer_proportion' in opt.keys():
|
||||
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion'])
|
||||
k += 1
|
||||
if 'channel_restriction_low' in opt.keys():
|
||||
self.channel_balancing_fn = functools.partial(channel_restriction, low=opt['channel_restriction_low'], high=opt['channel_restriction_high'])
|
||||
k += 1
|
||||
if not hasattr(self, 'channel_balancing_fn'):
|
||||
self.channel_balancing_fn = None
|
||||
assert k <= 1, 'Only one channel filtering function can be applied.'
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
|
@ -74,8 +80,6 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||
|
||||
self.recent_loss = diffusion_outputs['mse']
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user