gaussian_diffusion: support fp16
This commit is contained in:
parent
aa7cfd1edf
commit
76f86c0e47
|
@ -2,6 +2,7 @@ import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
||||||
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
|
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
|
||||||
|
@ -35,20 +36,21 @@ class GaussianDiffusionInjector(Injector):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
with autocast(enabled=self.env['opt']['fp16']):
|
||||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
||||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||||
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||||
|
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||||
|
|
||||||
if len(self.extra_model_output_keys) > 0:
|
if len(self.extra_model_output_keys) > 0:
|
||||||
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
|
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
|
||||||
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
|
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
|
||||||
else:
|
else:
|
||||||
out = {}
|
out = {}
|
||||||
out.update({self.output: diffusion_outputs['mse'],
|
out.update({self.output: diffusion_outputs['mse'],
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||||
|
|
||||||
# Absolutely critical to undo the above seed.
|
# Absolutely critical to undo the above seed.
|
||||||
if not gen.training:
|
if not gen.training:
|
||||||
|
@ -58,36 +60,13 @@ class GaussianDiffusionInjector(Injector):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class AutoregressiveGaussianDiffusionInjector(Injector):
|
def closest_multiple(inp, multiple):
|
||||||
def __init__(self, opt, env):
|
div = inp / multiple
|
||||||
super().__init__(opt, env)
|
mod = inp % multiple
|
||||||
self.generator = opt['generator']
|
if mod == 0:
|
||||||
self.output_variational_bounds_key = opt['out_key_vb_loss']
|
return inp
|
||||||
self.output_x_start_key = opt['out_key_x_start']
|
else:
|
||||||
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
|
return (div+1)*multiple
|
||||||
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
|
|
||||||
[opt['beta_schedule']['num_diffusion_timesteps']])
|
|
||||||
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
|
|
||||||
self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
|
|
||||||
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
|
||||||
self.model_output_keys = opt['model_output_keys']
|
|
||||||
self.model_eps_pred_key = opt['prediction_key']
|
|
||||||
|
|
||||||
def forward(self, state):
|
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
|
||||||
hq = state[self.input]
|
|
||||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
|
||||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
|
||||||
diffusion_outputs = self.diffusion.autoregressive_training_losses(gen, hq, t, self.model_output_keys,
|
|
||||||
self.model_eps_pred_key,
|
|
||||||
model_kwargs=model_inputs)
|
|
||||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
|
||||||
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
|
||||||
outputs = {k: diffusion_outputs[k] for k in self.model_output_keys}
|
|
||||||
outputs.update({self.output: diffusion_outputs['mse'],
|
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
# Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
|
# Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
|
||||||
|
@ -110,6 +89,7 @@ class GaussianDiffusionInferenceInjector(Injector):
|
||||||
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
||||||
self.use_ema_model = opt_get(opt, ['use_ema'], False)
|
self.use_ema_model = opt_get(opt, ['use_ema'], False)
|
||||||
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
|
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
|
||||||
|
self.multiple_requirement = opt_get(opt, ['multiple_requirement'], 4096)
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
if self.use_ema_model:
|
if self.use_ema_model:
|
||||||
|
@ -124,10 +104,10 @@ class GaussianDiffusionInferenceInjector(Injector):
|
||||||
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
|
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
|
||||||
dev = model_inputs['low_res'].device
|
dev = model_inputs['low_res'].device
|
||||||
elif 'spectrogram' in model_inputs.keys():
|
elif 'spectrogram' in model_inputs.keys():
|
||||||
output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1] * self.output_scale_factor)
|
output_shape = (self.output_batch_size, 1, closest_multiple(model_inputs['spectrogram'].shape[-1] * self.output_scale_factor, self.multiple_requirement))
|
||||||
dev = model_inputs['spectrogram'].device
|
dev = model_inputs['spectrogram'].device
|
||||||
elif 'discrete_spectrogram' in model_inputs.keys():
|
elif 'discrete_spectrogram' in model_inputs.keys():
|
||||||
output_shape = (self.output_batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1]*1024)
|
output_shape = (self.output_batch_size, 1, closest_multiple(model_inputs['discrete_spectrogram'].shape[-1]*1024, self.multiple_requirement))
|
||||||
dev = model_inputs['discrete_spectrogram'].device
|
dev = model_inputs['discrete_spectrogram'].device
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
Loading…
Reference in New Issue
Block a user