2021-06-03 03:47:32 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
2021-06-05 19:40:32 +00:00
|
|
|
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
|
2021-06-04 23:13:16 +00:00
|
|
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
2021-06-03 03:47:32 +00:00
|
|
|
from trainer.inject import Injector
|
|
|
|
from utils.util import opt_get
|
|
|
|
|
|
|
|
|
|
|
|
# Injects a Gaussian diffusion loss as described in OpenAI's "Improved Denoising Diffusion Probabilistic Models" paper.
|
|
|
|
# Largely uses OpenAI's own code to do so (all code from models.diffusion.*)
|
|
|
|
class GaussianDiffusionInjector(Injector):
    """Computes diffusion training losses for a generator and injects them into the state.

    Config keys read from ``opt``:
      generator: name of the generator in ``env['generators']`` to evaluate.
      out_key_vb_loss: state key that receives ``training_losses(...)['vb']``.
      out_key_x_start: state key that receives ``training_losses(...)['x_start_predicted']``.
      beta_schedule: kwargs forwarded to ``get_named_beta_schedule``; must contain
        ``num_diffusion_timesteps``.
      diffusion_args: kwargs forwarded to ``SpacedDiffusion``. NOTE: this dict is
        mutated in place ('betas' and 'use_timesteps' are written into it).
      sampler_type: name passed to ``create_named_schedule_sampler``.
      model_input_keys (optional): mapping of generator kwarg name -> state key;
        each value is forwarded to the generator as a keyword argument.

    The ``'mse'`` term is written to ``self.output`` (set up by the ``Injector`` base),
    and the input image is read from ``state[self.input]``.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.generator = opt['generator']
        self.output_variational_bounds_key = opt['out_key_vb_loss']
        self.output_x_start_key = opt['out_key_x_start']
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        # Respace to the full timestep count (a single section spanning every step),
        # i.e. presumably no respacing is applied during training.
        num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
        opt['diffusion_args']['use_timesteps'] = space_timesteps(num_timesteps, [num_timesteps])
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
        # Default to an empty mapping: forward() iterates this with .items(), so a
        # list default would raise AttributeError when the key is not configured.
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})

    def forward(self, state):
        """Sample timesteps, compute diffusion losses and return them keyed by output name."""
        gen = self.env['generators'][self.generator]
        hq = state[self.input]
        model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
        # One timestep per batch element (batch size taken from hq.shape[0]).
        # NOTE(review): `weights` (importance weights from the schedule sampler) are
        # discarded here; with a non-uniform sampler the returned losses are not
        # re-weighted. Confirm whether downstream losses are expected to apply them.
        t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
        # Loss-aware samplers adapt their timestep distribution from observed losses.
        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
        return {self.output: diffusion_outputs['mse'],
                self.output_variational_bounds_key: diffusion_outputs['vb'],
                self.output_x_start_key: diffusion_outputs['x_start_predicted']}
|
2021-06-03 03:47:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Performs inference using a network trained to predict a reverse diffusion process, which yields an image.
|
|
|
|
class GaussianDiffusionInferenceInjector(Injector):
    """Generates samples by running the reverse diffusion loop with a generator.

    Config keys read from ``opt``:
      generator: name of the generator in ``env['generators']`` to sample with.
      output_shape: shape passed to ``p_sample_loop``; element 0 is used as the
        batch size when slicing model inputs.
      beta_schedule: kwargs forwarded to ``get_named_beta_schedule``; must contain
        ``num_diffusion_timesteps``.
      respaced_timestep_spacing (optional): section count passed to
        ``space_timesteps``; defaults to ``num_diffusion_timesteps``.
      diffusion_args: kwargs forwarded to ``SpacedDiffusion``. NOTE: this dict is
        mutated in place ('betas' and 'use_timesteps' are written into it).
      model_input_keys (optional): mapping of generator kwarg name -> state key;
        each value is truncated to the batch size and forwarded to the generator.

    The sampled result is written to ``self.output`` (set up by the ``Injector`` base).
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.generator = opt['generator']
        self.output_shape = opt['output_shape']
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
        # Allows sampling with a different (presumably smaller) number of steps
        # than the schedule was defined with.
        spacing = opt_get(opt, ['respaced_timestep_spacing'], num_timesteps)
        opt['diffusion_args']['use_timesteps'] = space_timesteps(num_timesteps, [spacing])
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        # Default to an empty mapping: forward() iterates this with .items(), so a
        # list default would raise AttributeError when the key is not configured.
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})

    def forward(self, state):
        """Run the reverse diffusion loop and return the generated samples."""
        gen = self.env['generators'][self.generator]
        batch_size = self.output_shape[0]
        model_inputs = {k: state[v][:batch_size] for k, v in self.model_input_keys.items()}
        # Switch to eval mode for sampling, but restore the generator's previous
        # mode afterwards: the generator is shared through env, and leaving it in
        # eval mode would silently affect any training step that follows.
        was_training = gen.training
        gen.eval()
        try:
            with torch.no_grad():
                samples = self.diffusion.p_sample_loop(gen, self.output_shape, model_kwargs=model_inputs)
        finally:
            gen.train(was_training)
        return {self.output: samples}
|