2022-06-03 21:19:23 +00:00
|
|
|
import functools
|
2021-11-24 16:38:10 +00:00
|
|
|
|
2021-06-03 03:47:32 +00:00
|
|
|
import torch
|
2021-12-13 02:52:21 +00:00
|
|
|
from torch.cuda.amp import autocast
|
2021-06-03 03:47:32 +00:00
|
|
|
|
2022-06-10 03:41:20 +00:00
|
|
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
2022-03-04 17:40:14 +00:00
|
|
|
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
|
2021-06-04 23:13:16 +00:00
|
|
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
2021-06-03 03:47:32 +00:00
|
|
|
from trainer.inject import Injector
|
|
|
|
from utils.util import opt_get
|
|
|
|
|
|
|
|
|
2022-06-03 21:19:23 +00:00
|
|
|
def masked_channel_balancer(inp, proportion=1):
    """Randomly drop whole channels of ``inp``, keeping each with probability proportional to its mean.

    Per-channel means (over dims 0 and 2) are normalized into a distribution, scaled by
    ``num_channels * proportion`` and clamped to [0, 1]; each channel is then kept (mask 1) or
    zeroed (mask 0) by an independent Bernoulli draw. The mask is computed without gradient
    tracking, but the final multiply is outside ``no_grad`` so gradients reach kept channels.

    Only rank-3 inputs are supported (presumably (batch, channels, time) audio tensors — confirm).
    """
    with torch.no_grad():
        # One value per channel: average over every axis except dim 1.
        channel_means = inp.mean(dim=(0, 2))
        num_channels = channel_means.shape[0]
        keep_prob = (channel_means / channel_means.sum() * (num_channels * proportion)).clamp(0, 1)
        keep = torch.bernoulli(keep_prob)
    return inp * keep.view(1, inp.shape[1], 1)
|
|
|
|
|
|
|
|
|
2022-06-10 03:41:20 +00:00
|
|
|
def channel_restriction(inp, low, high):
    """Zero every channel of ``inp`` outside the half-open range ``[low, high)``.

    Channels are indexed along dim 1; all other dims are kept as-is. The result is ``inp`` multiplied
    by a 0/1 mask, so gradients flow only through the kept channels.

    Raises:
        AssertionError: if ``low`` is not a valid channel index or ``high`` exceeds the channel count.
    """
    num_channels = inp.shape[1]
    # low == 0 is a legitimate restriction (keep the first `high` channels); the old `low > 0` check rejected it.
    assert 0 <= low < num_channels and high <= num_channels, \
        f'Invalid channel range [{low}, {high}) for a tensor with {num_channels} channels.'
    mask = torch.zeros_like(inp)
    mask[:, low:high] = 1
    return inp * mask
|
|
|
|
|
|
|
|
|
2021-06-03 03:47:32 +00:00
|
|
|
# Injects a gaussian diffusion loss as described by OpenAI's "Improved Denoising Diffusion Probabilistic Models" paper.
|
|
|
|
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
|
|
|
|
class GaussianDiffusionInjector(Injector):
    """Training-time injector that computes Gaussian diffusion losses for a generator.

    Reads the clean target from ``state[self.input]``, samples timesteps, calls
    ``self.diffusion.training_losses`` and writes its 'mse', 'vb' and 'x_start_predicted' outputs
    (plus any configured extra model outputs) back into the state.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.generator = opt['generator']
        self.output_variational_bounds_key = opt['out_key_vb_loss']
        self.output_x_start_key = opt['out_key_x_start']
        # NOTE: these two assignments mutate the caller's opt['diffusion_args'] in place.
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
        # Training uses the full, un-respaced timestep schedule.
        opt['diffusion_args']['use_timesteps'] = space_timesteps(num_timesteps, [num_timesteps])
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
        # Maps model kwarg name -> state key (str) or literal value. forward() calls .items() on it,
        # so the default must be a dict (a list default crashed when the option was absent).
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})
        self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
        self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
        self.deterministic_sampler = DeterministicSampler(
            self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)

        # Optional channel filtering passed to training_losses(); at most one may be configured.
        self.channel_balancing_fn = None
        num_filters = 0
        if 'channel_balancer_proportion' in opt:
            self.channel_balancing_fn = functools.partial(
                masked_channel_balancer, proportion=opt['channel_balancer_proportion'])
            num_filters += 1
        if 'channel_restriction_low' in opt:
            self.channel_balancing_fn = functools.partial(
                channel_restriction, low=opt['channel_restriction_low'], high=opt['channel_restriction_high'])
            num_filters += 1
        assert num_filters <= 1, 'Only one channel filtering function can be applied.'

    def extra_metrics(self):
        """Return extra metrics to log: whether a LossSecondMomentResampler has warmed up, else nothing."""
        # Imported here because the module-level import list does not include this class;
        # referencing it unqualified raised NameError every time this method ran.
        from models.diffusion.resample import LossSecondMomentResampler
        sampler = getattr(self, 'schedule_sampler', None)
        if isinstance(sampler, LossSecondMomentResampler):
            return {'sampler_warmed_up': sampler._warmed_up()}
        return {}

    def forward(self, state):
        gen = self.env['generators'][self.opt['generator']]
        hq = state[self.input]
        # Inputs must lie roughly within [-1, 1]: exceeding either bound indicates un-normalized data.
        # (Previously `or` let e.g. [0, 255] data through because its minimum is above -1.5.)
        assert hq.max() < 1.5 and hq.min() > -1.5, "Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly!"

        with autocast(enabled=self.env['opt']['fp16']):
            # Deterministic timesteps are used in eval mode and, optionally, every N training steps.
            use_deterministic = (not gen.training) or (
                self.deterministic_timesteps_every != 0
                and self.env['step'] % self.deterministic_timesteps_every == 0)
            if use_deterministic:
                sampler = self.deterministic_sampler
            else:
                sampler = self.schedule_sampler
                # Keep this reset whenever it is not being used, so it is ready to use automatically.
                self.deterministic_sampler.reset()
            model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
            # NOTE(review): the second value from sampler.sample() (presumably importance weights) is not
            # applied to the losses here — confirm this is intended for non-uniform samplers.
            t, _ = sampler.sample(hq.shape[0], hq.device)
            diffusion_outputs = self.diffusion.training_losses(
                gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
            if isinstance(sampler, LossAwareSampler):
                sampler.update_with_local_losses(t, diffusion_outputs['loss'])
            if self.extra_model_output_keys:
                assert len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs'])
                out = dict(zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs']))
            else:
                out = {}
            out.update({self.output: diffusion_outputs['mse'],
                        self.output_variational_bounds_key: diffusion_outputs['vb'],
                        self.output_x_start_key: diffusion_outputs['x_start_predicted']})

        return out
|
2021-06-03 03:47:32 +00:00
|
|
|
|
|
|
|
|
2021-12-13 02:52:21 +00:00
|
|
|
def closest_multiple(inp, multiple):
    """Round ``inp`` up to the nearest multiple of ``multiple`` and return it as an ``int``.

    Despite the name, this is a ceiling: values that are already multiples are returned as-is
    (converted to int); anything else goes up to the next multiple.

    Raises:
        ZeroDivisionError: if ``multiple`` is 0.
    """
    div, mod = divmod(inp, multiple)
    if mod == 0:
        # Cast so both branches return an int. Previously a float that was already an exact
        # multiple (e.g. 8192.0) came back as a float, which tensor shape arguments reject.
        return int(inp)
    return int((div + 1) * multiple)
|
2021-07-26 22:27:31 +00:00
|
|
|
|
|
|
|
|
2021-06-03 03:47:32 +00:00
|
|
|
# Performs inference using a network trained to predict a reverse diffusion process, which produces an image.
|
|
|
|
class GaussianDiffusionInferenceInjector(Injector):
    """Inference-time injector that runs the reverse diffusion process to generate samples.

    Generates ``output_batch_size`` outputs conditioned on the configured model inputs, using
    ``ddim_sample_loop`` when ``use_ddim`` is set and ``p_sample_loop`` otherwise, and writes them
    to ``state[self.output]``.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        use_ddim = opt_get(opt, ['use_ddim'], False)
        self.generator = opt['generator']
        self.output_batch_size = opt['output_batch_size']
        self.output_scale_factor = opt['output_scale_factor']
        # When set, shifts the output of this injector from [-1,1] to [0,1].
        self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False)
        # NOTE: mutates the caller's opt['diffusion_args'] in place.
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
        if use_ddim:
            spacing = "ddim" + str(opt['respaced_timestep_spacing'])
        else:
            spacing = [opt_get(opt, ['respaced_timestep_spacing'], num_timesteps)]
        opt['diffusion_args']['use_timesteps'] = space_timesteps(num_timesteps, spacing)
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
        # Maps model kwarg name -> state key. forward() calls .items() on it, so the default
        # must be a dict (a list default crashed when the option was absent).
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})
        self.use_ema_model = opt_get(opt, ['use_ema'], False)
        self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
        self.multiple_requirement = opt_get(opt, ['multiple_requirement'], 4096)

    def _output_shape_and_device(self, model_inputs):
        """Return ``(output_shape, device)`` for the samples, derived from the conditioning input present."""
        if 'low_res' in model_inputs:
            low_res = model_inputs['low_res']
            shape = (self.output_batch_size, 3, low_res.shape[-2] * self.output_scale_factor,
                     low_res.shape[-1] * self.output_scale_factor)
            return shape, low_res.device
        if 'spectrogram' in model_inputs:
            spec = model_inputs['spectrogram']
            length = closest_multiple(spec.shape[-1] * self.output_scale_factor, self.multiple_requirement)
            return (self.output_batch_size, 1, length), spec.device
        if 'discrete_spectrogram' in model_inputs:
            codes = model_inputs['discrete_spectrogram']
            # NOTE(review): fixed factor of 1024 output samples per input step; ignores output_scale_factor.
            length = closest_multiple(codes.shape[-1] * 1024, self.multiple_requirement)
            return (self.output_batch_size, 1, length), codes.device
        raise NotImplementedError

    def _initial_noise(self, output_shape, dev):
        """Return starting noise: zeros, a cached fixed draw, or None (left to the sampling function)."""
        if self.noise_style == 'zero':
            return torch.zeros(output_shape, device=dev)
        if self.noise_style == 'fixed':
            # Re-draw only when the requested shape changes so repeated calls reuse the same noise.
            if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
                self.fixed_noise = torch.randn(output_shape, device=dev)
            return self.fixed_noise
        return None

    def forward(self, state):
        if self.use_ema_model:
            gen = self.env['emas'][self.opt['generator']]
        else:
            gen = self.env['generators'][self.opt['generator']]
        model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
        # NOTE(review): the generator is left in eval mode afterwards; presumably the trainer
        # re-enables train mode before the next training step — confirm.
        gen.eval()
        with torch.no_grad():
            output_shape, dev = self._output_shape_and_device(model_inputs)
            noise = self._initial_noise(output_shape, dev)
            samples = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs,
                                       progress=True, device=dev)
            if self.undo_n1_to_1:
                samples = (samples + 1) / 2
        return {self.output: samples}
|