setup for partial channel diffusion
This commit is contained in:
parent
47b34f5cb9
commit
34005367fd
|
@ -777,13 +777,13 @@ class GaussianDiffusion:
|
||||||
kl = normal_kl(
|
kl = normal_kl(
|
||||||
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
||||||
)
|
)
|
||||||
kl = mean_flat(kl) / np.log(2.0)
|
kl = kl / np.log(2.0)
|
||||||
|
|
||||||
decoder_nll = -discretized_gaussian_log_likelihood(
|
decoder_nll = -discretized_gaussian_log_likelihood(
|
||||||
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
||||||
)
|
)
|
||||||
assert decoder_nll.shape == x_start.shape
|
assert decoder_nll.shape == x_start.shape
|
||||||
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
decoder_nll = decoder_nll / np.log(2.0)
|
||||||
|
|
||||||
# At the first timestep return the decoder NLL,
|
# At the first timestep return the decoder NLL,
|
||||||
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||||
|
@ -813,14 +813,14 @@ class GaussianDiffusion:
|
||||||
|
|
||||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||||
# TODO: support multiple model outputs for this mode.
|
# TODO: support multiple model outputs for this mode.
|
||||||
terms["loss"] = self._vb_terms_bpd(
|
terms["loss"] = mean_flat(self._vb_terms_bpd(
|
||||||
model=model,
|
model=model,
|
||||||
x_start=x_start,
|
x_start=x_start,
|
||||||
x_t=x_t,
|
x_t=x_t,
|
||||||
t=t,
|
t=t,
|
||||||
clip_denoised=False,
|
clip_denoised=False,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)["output"]
|
)["output"])
|
||||||
if self.loss_type == LossType.RESCALED_KL:
|
if self.loss_type == LossType.RESCALED_KL:
|
||||||
terms["loss"] *= self.num_timesteps
|
terms["loss"] *= self.num_timesteps
|
||||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||||
|
@ -873,79 +873,9 @@ class GaussianDiffusion:
|
||||||
terms["mse"] = mean_flat(s_err)
|
terms["mse"] = mean_flat(s_err)
|
||||||
terms["x_start_predicted"] = x_start_pred
|
terms["x_start_predicted"] = x_start_pred
|
||||||
if "vb" in terms:
|
if "vb" in terms:
|
||||||
terms["loss"] = terms["mse"] + terms["vb"]
|
if channel_balancing_fn is not None:
|
||||||
else:
|
terms["vb"] = channel_balancing_fn(terms["vb"])
|
||||||
terms["loss"] = terms["mse"]
|
terms["loss"] = terms["mse"] + mean_flat(terms["vb"])
|
||||||
else:
|
|
||||||
raise NotImplementedError(self.loss_type)
|
|
||||||
|
|
||||||
return terms
|
|
||||||
|
|
||||||
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
|
|
||||||
"""
|
|
||||||
Compute training losses for a single timestep.
|
|
||||||
|
|
||||||
:param model: the model to evaluate loss on.
|
|
||||||
:param x_start: the [N x C x ...] tensor of inputs.
|
|
||||||
:param t: a batch of timestep indices.
|
|
||||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
|
||||||
pass to the model. This can be used for conditioning.
|
|
||||||
:param noise: if specified, the specific Gaussian noise to try to remove.
|
|
||||||
:return: a dict with the key "loss" containing a tensor of shape [N].
|
|
||||||
Some mean or variance settings may also have other keys.
|
|
||||||
"""
|
|
||||||
if model_kwargs is None:
|
|
||||||
model_kwargs = {}
|
|
||||||
if noise is None:
|
|
||||||
noise = th.randn_like(x_start)
|
|
||||||
x_t = self.q_sample(x_start, t, noise=noise)
|
|
||||||
terms = {}
|
|
||||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
|
||||||
assert False # not currently supported for this type of diffusion.
|
|
||||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
|
||||||
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
|
|
||||||
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
|
|
||||||
model_output = terms[gd_out_key]
|
|
||||||
if self.model_var_type in [
|
|
||||||
ModelVarType.LEARNED,
|
|
||||||
ModelVarType.LEARNED_RANGE,
|
|
||||||
]:
|
|
||||||
B, C = x_t.shape[:2]
|
|
||||||
assert model_output.shape == (B, C, 2, *x_t.shape[2:])
|
|
||||||
model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
|
|
||||||
# Learn the variance using the variational bound, but don't let
|
|
||||||
# it affect our mean prediction.
|
|
||||||
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
|
||||||
terms["vb"] = self._vb_terms_bpd(
|
|
||||||
model=lambda *args, r=frozen_out: r,
|
|
||||||
x_start=x_start,
|
|
||||||
x_t=x_t,
|
|
||||||
t=t,
|
|
||||||
clip_denoised=False,
|
|
||||||
)["output"]
|
|
||||||
if self.loss_type == LossType.RESCALED_MSE:
|
|
||||||
# Divide by 1000 for equivalence with initial implementation.
|
|
||||||
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
|
||||||
terms["vb"] *= self.num_timesteps / 1000.0
|
|
||||||
|
|
||||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
|
||||||
target = self.q_posterior_mean_variance(
|
|
||||||
x_start=x_start, x_t=x_t, t=t
|
|
||||||
)[0]
|
|
||||||
x_start_pred = torch.zeros(x_start) # Not supported.
|
|
||||||
elif self.model_mean_type == ModelMeanType.START_X:
|
|
||||||
target = x_start
|
|
||||||
x_start_pred = model_output
|
|
||||||
elif self.model_mean_type == ModelMeanType.EPSILON:
|
|
||||||
target = noise
|
|
||||||
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(self.model_mean_type)
|
|
||||||
assert model_output.shape == target.shape == x_start.shape
|
|
||||||
terms["mse"] = mean_flat((target - model_output) ** 2)
|
|
||||||
terms["x_start_predicted"] = x_start_pred
|
|
||||||
if "vb" in terms:
|
|
||||||
terms["loss"] = terms["mse"] + terms["vb"]
|
|
||||||
else:
|
else:
|
||||||
terms["loss"] = terms["mse"]
|
terms["loss"] = terms["mse"]
|
||||||
else:
|
else:
|
||||||
|
@ -1001,14 +931,14 @@ class GaussianDiffusion:
|
||||||
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
||||||
# Calculate VLB term at the current timestep
|
# Calculate VLB term at the current timestep
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
out = self._vb_terms_bpd(
|
out = mean_flat(self._vb_terms_bpd(
|
||||||
model,
|
model,
|
||||||
x_start=x_start,
|
x_start=x_start,
|
||||||
x_t=x_t,
|
x_t=x_t,
|
||||||
t=t_batch,
|
t=t_batch,
|
||||||
clip_denoised=clip_denoised,
|
clip_denoised=clip_denoised,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
))
|
||||||
vb.append(out["output"])
|
vb.append(out["output"])
|
||||||
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
||||||
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
||||||
|
|
|
@ -95,11 +95,6 @@ class SpacedDiffusion(GaussianDiffusion):
|
||||||
): # pylint: disable=signature-differs
|
): # pylint: disable=signature-differs
|
||||||
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
||||||
|
|
||||||
def autoregressive_training_losses(
|
|
||||||
self, model, *args, **kwargs
|
|
||||||
): # pylint: disable=signature-differs
|
|
||||||
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
|
|
||||||
|
|
||||||
def condition_mean(self, cond_fn, *args, **kwargs):
|
def condition_mean(self, cond_fn, *args, **kwargs):
|
||||||
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
@ -62,6 +63,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
self.local_modules['codegen'] = get_music_codegen()
|
self.local_modules['codegen'] = get_music_codegen()
|
||||||
elif 'from_codes_quant' == mode:
|
elif 'from_codes_quant' == mode:
|
||||||
self.diffusion_fn = self.perform_diffusion_from_codes_quant
|
self.diffusion_fn = self.perform_diffusion_from_codes_quant
|
||||||
|
elif 'partial_from_codes_quant' == mode:
|
||||||
|
self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant,
|
||||||
|
partial_low=opt_eval['partial_low'],
|
||||||
|
partial_high=opt_eval['partial_high'])
|
||||||
elif 'from_codes_quant_gradual_decode' == mode:
|
elif 'from_codes_quant_gradual_decode' == mode:
|
||||||
self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
|
self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
|
||||||
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
|
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
|
||||||
|
@ -135,6 +140,32 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
|
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
|
||||||
|
|
||||||
|
def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
    """
    Regenerate a band of MEL channels with the diffusion model, then decode the MEL to a waveform.

    Channels in [partial_low, partial_high) of the normalized MEL are masked out and predicted by
    self.model; all remaining channels are handed to the sampler as guidance taken from the MEL of
    the input audio. The resulting MEL is denormalized and decoded by self.spec_decoder.

    :param audio: the input waveform tensor. A leading dim is added before it is passed to
                  self.spec_fn. NOTE(review): presumably 22050Hz audio - confirm against the caller.
    :param sample_rate: rate for the returned reference audio. When it differs from 22050 the
                        reference is resampled from 22050 to this rate; it is also returned as-is.
    :param partial_low: first MEL channel (inclusive) that the model predicts.
    :param partial_high: last MEL channel (exclusive) that the model predicts.
    :return: a tuple (gen_wav, real_resampled, gen_mel, mel_norm, sample_rate).
    """
    # Rate that the resample call below treats `audio` as being sampled at.
    source_rate = 22050
    # The guard used to compare sample_rate with itself, which is never true, so the resample
    # branch was unreachable. With the default sample_rate the behavior is unchanged.
    if sample_rate != source_rate:
        real_resampled = torchaudio.functional.resample(audio, source_rate, sample_rate).unsqueeze(0)
    else:
        # NOTE(review): unlike the branch above this reference is not unsqueezed - confirm intended.
        real_resampled = audio
    audio = audio.unsqueeze(0)

    mel = self.spec_fn({'in': audio})['out']
    mel_norm = normalize_mel(mel)
    mask = torch.ones_like(mel_norm)
    mask[:, partial_low:partial_high] = 0  # This is the channel region that the model will predict.
    gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
                                                        guidance_input=mel_norm, mask=mask,
                                                        model_kwargs={'truth_mel': mel,
                                                                      # All-zero conditioning, sized from the first 390 entries of dim 2.
                                                                      'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
                                                                      'disable_diversity': True})

    gen_mel_denorm = denormalize_mel(gen_mel)
    # The decoder samples 16 channels of 1/16th the audio length; pixel_shuffle_1d(…, 16) is applied
    # to its output below (presumably folding those channels back into one waveform - TODO confirm).
    output_shape = (1,16,audio.shape[-1]//16)
    self.spec_decoder = self.spec_decoder.to(audio.device)
    gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                          model_kwargs={'aligned_conditioning': gen_mel_denorm})
    gen_wav = pixel_shuffle_1d(gen_wav, 16)

    return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
|
||||||
|
|
||||||
def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
|
def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
|
||||||
if sample_rate != sample_rate:
|
if sample_rate != sample_rate:
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
|
@ -243,7 +274,8 @@ if __name__ == '__main__':
|
||||||
'path': 'E:\\music_eval',
|
'path': 'E:\\music_eval',
|
||||||
'diffusion_steps': 100,
|
'diffusion_steps': 100,
|
||||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
|
'diffusion_schedule': 'linear', 'diffusion_type': 'partial_from_codes_quant',
|
||||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 502, 'device': 'cuda', 'opt': {}}
|
'partial_low': 128, 'partial_high': 192}
|
||||||
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 504, 'device': 'cuda', 'opt': {}}
|
||||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||||
print(eval.perform_eval())
|
print(eval.perform_eval())
|
||||||
|
|
|
@ -1,18 +1,15 @@
|
||||||
import functools
|
import functools
|
||||||
import random
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||||
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
|
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
|
||||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||||
from trainer.inject import Injector
|
from trainer.inject import Injector
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def masked_channel_balancer(inp, proportion=1):
|
def masked_channel_balancer(inp, proportion=1):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
|
only_channels = inp.mean(dim=(0,2)) # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
|
||||||
|
@ -23,6 +20,12 @@ def masked_channel_balancer(inp, proportion=1):
|
||||||
return inp * mask.view(1,inp.shape[1],1)
|
return inp * mask.view(1,inp.shape[1],1)
|
||||||
|
|
||||||
|
|
||||||
|
def channel_restriction(inp, low, high):
    """
    Zero every channel of `inp` outside of [low, high).

    Channels are taken to be dim 1 of `inp`. Elements inside the range are passed through
    unchanged (multiplied by 1), elements outside of it are multiplied by 0.

    :param inp: the tensor to restrict; must have at least 2 dims.
    :param low: first channel (inclusive) to keep.
    :param high: last channel (exclusive) to keep.
    :return: a tensor with the same shape as `inp`.
    """
    channels = inp.shape[1]
    # Only one mask entry per channel is needed; broadcasting applies it over every other dim,
    # which avoids allocating a second tensor the size of `inp`.
    keep = inp.new_zeros(channels)
    keep[low:high] = 1
    trailing = [1] * (inp.dim() - 2)
    return inp * keep.view(1, channels, *trailing)
|
||||||
|
|
||||||
|
|
||||||
# Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper.
|
# Injects a gaussian diffusion loss as described by OpenAIs "Improved Denoising Diffusion Probabilistic Models" paper.
|
||||||
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
|
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
|
||||||
class GaussianDiffusionInjector(Injector):
|
class GaussianDiffusionInjector(Injector):
|
||||||
|
@ -40,14 +43,17 @@ class GaussianDiffusionInjector(Injector):
|
||||||
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
||||||
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
|
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
|
||||||
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
|
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
|
||||||
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion']) \
|
|
||||||
if 'channel_balancer_proportion' in opt.keys() else None
|
|
||||||
self.recent_loss = 0
|
|
||||||
|
|
||||||
def extra_metrics(self):
|
k = 0
|
||||||
return {
|
if 'channel_balancer_proportion' in opt.keys():
|
||||||
'exp_diffusion_loss': torch.exp(self.recent_loss.mean()),
|
self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion'])
|
||||||
}
|
k += 1
|
||||||
|
if 'channel_restriction_low' in opt.keys():
|
||||||
|
self.channel_balancing_fn = functools.partial(channel_restriction, low=opt['channel_restriction_low'], high=opt['channel_restriction_high'])
|
||||||
|
k += 1
|
||||||
|
if not hasattr(self, 'channel_balancing_fn'):
|
||||||
|
self.channel_balancing_fn = None
|
||||||
|
assert k <= 1, 'Only one channel filtering function can be applied.'
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
|
@ -74,8 +80,6 @@ class GaussianDiffusionInjector(Injector):
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||||
|
|
||||||
self.recent_loss = diffusion_outputs['mse']
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user