From 76f86c0e47f07123c07794a26112d9914219235f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 12 Dec 2021 19:52:21 -0700
Subject: [PATCH] gaussian_diffusion: support fp16

---
Review notes (this section is ignored by `git am`):
 * closest_multiple() now uses floor division (`//`). With true division
   (`/`) the helper returned an over-sized float for any input that is not
   already a multiple, e.g. closest_multiple(5, 4) gave 9.0 instead of 8,
   and that value is used directly as a dimension of output_shape.
 * Line breaks and indentation were reconstructed from a flattened copy of
   this patch; confirm whitespace against the repository before applying.
   The blob id on the `index` line predates the `//` change.

 .../injectors/gaussian_diffusion_injector.py  | 70 +++++++------------
 1 file changed, 25 insertions(+), 45 deletions(-)

diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index af6ff70e..28f2e2df 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -2,6 +2,7 @@ import random
 import time
 
 import torch
+from torch.cuda.amp import autocast
 
 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
 from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
@@ -35,20 +36,21 @@ class GaussianDiffusionInjector(Injector):
             torch.manual_seed(0)
             random.seed(0)
 
-        model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
-        t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
-        if isinstance(self.schedule_sampler, LossAwareSampler):
-            self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
+        with autocast(enabled=self.env['opt']['fp16']):
+            model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
+            t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
+            diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+            if isinstance(self.schedule_sampler, LossAwareSampler):
+                self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
 
-        if len(self.extra_model_output_keys) > 0:
-            assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
-            out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
-        else:
-            out = {}
-        out.update({self.output: diffusion_outputs['mse'],
-                    self.output_variational_bounds_key: diffusion_outputs['vb'],
-                    self.output_x_start_key: diffusion_outputs['x_start_predicted']})
+            if len(self.extra_model_output_keys) > 0:
+                assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
+                out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
+            else:
+                out = {}
+            out.update({self.output: diffusion_outputs['mse'],
+                        self.output_variational_bounds_key: diffusion_outputs['vb'],
+                        self.output_x_start_key: diffusion_outputs['x_start_predicted']})
 
         # Absolutely critical to undo the above seed.
         if not gen.training:
@@ -58,36 +60,13 @@ class GaussianDiffusionInjector(Injector):
         return out
 
 
-class AutoregressiveGaussianDiffusionInjector(Injector):
-    def __init__(self, opt, env):
-        super().__init__(opt, env)
-        self.generator = opt['generator']
-        self.output_variational_bounds_key = opt['out_key_vb_loss']
-        self.output_x_start_key = opt['out_key_x_start']
-        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
-        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
-                                                                 [opt['beta_schedule']['num_diffusion_timesteps']])
-        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
-        self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
-        self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
-        self.model_output_keys = opt['model_output_keys']
-        self.model_eps_pred_key = opt['prediction_key']
-
-    def forward(self, state):
-        gen = self.env['generators'][self.opt['generator']]
-        hq = state[self.input]
-        model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
-        t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.autoregressive_training_losses(gen, hq, t, self.model_output_keys,
-                                                                          self.model_eps_pred_key,
-                                                                          model_kwargs=model_inputs)
-        if isinstance(self.schedule_sampler, LossAwareSampler):
-            self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
-        outputs = {k: diffusion_outputs[k] for k in self.model_output_keys}
-        outputs.update({self.output: diffusion_outputs['mse'],
-                        self.output_variational_bounds_key: diffusion_outputs['vb'],
-                        self.output_x_start_key: diffusion_outputs['x_start_predicted']})
-        return outputs
+def closest_multiple(inp, multiple):
+    div = inp // multiple
+    mod = inp % multiple
+    if mod == 0:
+        return inp
+    else:
+        return (div+1)*multiple
 
 
 # Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
@@ -110,6 +89,7 @@ class GaussianDiffusionInferenceInjector(Injector):
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.use_ema_model = opt_get(opt, ['use_ema'], False)
         self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
+        self.multiple_requirement = opt_get(opt, ['multiple_requirement'], 4096)
 
     def forward(self, state):
         if self.use_ema_model:
@@ -124,10 +104,10 @@ class GaussianDiffusionInferenceInjector(Injector):
                                 model_inputs['low_res'].shape[-1] * self.output_scale_factor)
                 dev = model_inputs['low_res'].device
             elif 'spectrogram' in model_inputs.keys():
-                output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1] * self.output_scale_factor)
+                output_shape = (self.output_batch_size, 1, closest_multiple(model_inputs['spectrogram'].shape[-1] * self.output_scale_factor, self.multiple_requirement))
                 dev = model_inputs['spectrogram'].device
             elif 'discrete_spectrogram' in model_inputs.keys():
-                output_shape = (self.output_batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1]*1024)
+                output_shape = (self.output_batch_size, 1, closest_multiple(model_inputs['discrete_spectrogram'].shape[-1]*1024, self.multiple_requirement))
                 dev = model_inputs['discrete_spectrogram'].device
             else:
                 raise NotImplementedError