DL-Art-School/codes/trainer/injectors/gaussian_diffusion_injector.py

import random
import time
import torch
from torch.cuda.amp import autocast
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.inject import Injector
from utils.util import opt_get


# Injects a gaussian diffusion loss as described by OpenAI's "Improved Denoising Diffusion Probabilistic Models"
# paper. Largely uses OpenAI's own code to do so (all code from models.diffusion.*).
class GaussianDiffusionInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.generator = opt['generator']
        self.output_variational_bounds_key = opt['out_key_vb_loss']
        self.output_x_start_key = opt['out_key_x_start']
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
                                                                 [opt['beta_schedule']['num_diffusion_timesteps']])
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
        # model_input_keys maps model kwarg names to state keys and is consumed as a dict in forward(),
        # so the default must be {} rather than [].
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})
        self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])

    def forward(self, state):
        gen = self.env['generators'][self.opt['generator']]
        hq = state[self.input]

        # In eval mode, seed torch with a deterministic seed for reproducibility.
        if not gen.training:
            torch.manual_seed(0)
            random.seed(0)

        with autocast(enabled=self.env['opt']['fp16']):
            model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
            t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
            diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])

        if len(self.extra_model_output_keys) > 0:
            assert len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs'])
            out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
        else:
            out = {}
        out.update({self.output: diffusion_outputs['mse'],
                    self.output_variational_bounds_key: diffusion_outputs['vb'],
                    self.output_x_start_key: diffusion_outputs['x_start_predicted']})

        # Absolutely critical to undo the above seed.
        if not gen.training:
            torch.manual_seed(int(time.time()))
            random.seed(int(time.time()))
        return out
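

# A minimal sketch of the opt block this injector consumes, assuming the dict/YAML option style used throughout
# DL-Art-School. Keys not read by __init__ above (notably 'type' and 'in') and all concrete values shown are
# illustrative assumptions, not guarantees about any particular config:
#
#   example_opt = {
#       'type': 'gaussian_diffusion',
#       'in': 'hq',                              # state key holding the training target
#       'generator': 'generator',                # name registered in env['generators']
#       'sampler_type': 'uniform',               # see models.diffusion.resample
#       'beta_schedule': {'schedule_name': 'linear', 'num_diffusion_timesteps': 4000},
#       'diffusion_args': {...},                 # forwarded (with betas/use_timesteps added) to SpacedDiffusion
#       'model_input_keys': {'low_res': 'lq'},   # model kwarg name -> state key
#       'out': 'mse_loss',
#       'out_key_vb_loss': 'vb_loss',
#       'out_key_x_start': 'x_start_pred',
#   }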


# Rounds `inp` up to the nearest multiple of `multiple`.
def closest_multiple(inp, multiple):
    div = inp // multiple
    mod = inp % multiple
    if mod == 0:
        return inp
    else:
        return int((div + 1) * multiple)
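
# For example: closest_multiple(4096, 4096) == 4096, while closest_multiple(4097, 4096) == 8192. The inference
# injector below uses this to pad generated sequence lengths up to `multiple_requirement`.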


# Performs inference using a network trained to predict a reverse diffusion process, which nets an image.
class GaussianDiffusionInferenceInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        use_ddim = opt_get(opt, ['use_ddim'], False)
        self.generator = opt['generator']
        self.output_batch_size = opt['output_batch_size']
        self.output_scale_factor = opt['output_scale_factor']
        # When specified, shifts the output of this injector from [-1,1] to [0,1].
        self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False)
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        if use_ddim:
            spacing = "ddim" + str(opt['respaced_timestep_spacing'])
        else:
            spacing = [opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])]
        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
                                                                 spacing)
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
        # Consumed as a dict (model kwarg name -> state key), so default to {} rather than [].
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})
        self.use_ema_model = opt_get(opt, ['use_ema'], False)
        self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
        self.multiple_requirement = opt_get(opt, ['multiple_requirement'], 4096)

    def forward(self, state):
        if self.use_ema_model:
            gen = self.env['emas'][self.opt['generator']]
        else:
            gen = self.env['generators'][self.opt['generator']]
        model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
        gen.eval()

        with torch.no_grad():
            # Infer the output shape (and device) from whichever conditioning input is present.
            if 'low_res' in model_inputs.keys():
                output_shape = (self.output_batch_size, 3,
                                model_inputs['low_res'].shape[-2] * self.output_scale_factor,
                                model_inputs['low_res'].shape[-1] * self.output_scale_factor)
                dev = model_inputs['low_res'].device
            elif 'spectrogram' in model_inputs.keys():
                output_shape = (self.output_batch_size, 1,
                                closest_multiple(model_inputs['spectrogram'].shape[-1] * self.output_scale_factor,
                                                 self.multiple_requirement))
                dev = model_inputs['spectrogram'].device
            elif 'discrete_spectrogram' in model_inputs.keys():
                output_shape = (self.output_batch_size, 1,
                                closest_multiple(model_inputs['discrete_spectrogram'].shape[-1] * 1024,
                                                 self.multiple_requirement))
                dev = model_inputs['discrete_spectrogram'].device
            else:
                raise NotImplementedError

            noise = None
            if self.noise_style == 'zero':
                noise = torch.zeros(output_shape, device=dev)
            elif self.noise_style == 'fixed':
                if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
                    self.fixed_noise = torch.randn(output_shape, device=dev)
                noise = self.fixed_noise

            gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True,
                                   device=dev)
            if self.undo_n1_to_1:
                gen = (gen + 1) / 2
            return {self.output: gen}
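

# A minimal usage sketch, assuming env/state are laid out the way the trainer builds them and that an injector
# can be invoked directly; the concrete key names ('generator', 'mel', 'generated') and the DDIM step count are
# illustrative assumptions:
#
#   inj = GaussianDiffusionInferenceInjector({
#       'type': 'gaussian_diffusion_inference',
#       'generator': 'generator',
#       'output_batch_size': 1,
#       'output_scale_factor': 1,
#       'beta_schedule': {'schedule_name': 'linear', 'num_diffusion_timesteps': 4000},
#       'diffusion_args': {...},
#       'use_ddim': True,
#       'respaced_timestep_spacing': 50,          # 50 DDIM steps instead of the full 4000
#       'model_input_keys': {'spectrogram': 'mel'},
#       'out': 'generated',
#   }, env)
#   generated = inj.forward(state)['generated']   # shape: (1, 1, length padded to multiple_requirement)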