import functools
import random

import torch
from torch.cuda.amp import autocast

from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler, \
    LossSecondMomentResampler
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.inject import Injector
from utils.util import opt_get


def masked_channel_balancer(inp, proportion=1):
    with torch.no_grad():
        # Only currently works for audio tensors. Could be retrofitted for 2d (or 3D!) modalities.
        only_channels = inp.mean(dim=(0, 2))
        dist = only_channels / only_channels.sum()
        dist_mult = only_channels.shape[0] * proportion
        dist = (dist * dist_mult).clamp(0, 1)
        mask = torch.bernoulli(dist)
    return inp * mask.view(1, inp.shape[1], 1)
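
# A minimal sketch of the balancer's effect, assuming a (batch, channels, time) audio tensor
# (illustrative only; not part of the training path):
#
#   x = torch.randn(4, 80, 400).abs()              # hypothetical spectrogram-like input
#   y = masked_channel_balancer(x, proportion=.5)  # roughly half of the 80 channels survive
#
# Survival probability scales with each channel's share of the total mean, so higher-energy
# channels are kept more often.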


def channel_restriction(inp, low, high):
    assert 0 < low < inp.shape[1] and high <= inp.shape[1]
    m = torch.zeros_like(inp)
    m[:, low:high] = 1
    return inp * m
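
# E.g., channel_restriction(x, low=10, high=40) zeroes every channel outside [10, 40), which can
# be used to restrict training to a fixed sub-band of a spectrogram.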


# Injects a gaussian diffusion loss as described by OpenAI's "Improved Denoising Diffusion Probabilistic Models"
# paper. Largely uses OpenAI's own code to do so (all code from models.diffusion.*).
class GaussianDiffusionInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.generator = opt['generator']
        self.output_variational_bounds_key = opt['out_key_vb_loss']
        self.output_x_start_key = opt['out_key_x_start']
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        # Train on the full, un-respaced timestep schedule.
        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
                                                                 [opt['beta_schedule']['num_diffusion_timesteps']])
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})  # Maps model kwarg names to state keys; forward() calls .items() on it.
        self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
        self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
        self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
        self.causal_mode = opt_get(opt, ['causal_mode'], False)
        self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1, 8])
        self.preprocess_fn = opt_get(opt, ['preprocess_fn'], None)

        # At most one channel filtering function may be configured.
        k = 0
        if 'channel_balancer_proportion' in opt.keys():
            self.channel_balancing_fn = functools.partial(masked_channel_balancer, proportion=opt['channel_balancer_proportion'])
            k += 1
        if 'channel_restriction_low' in opt.keys():
            self.channel_balancing_fn = functools.partial(channel_restriction, low=opt['channel_restriction_low'], high=opt['channel_restriction_high'])
            k += 1
        if not hasattr(self, 'channel_balancing_fn'):
            self.channel_balancing_fn = None
        assert k <= 1, 'Only one channel filtering function can be applied.'

        self.num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
        self.latest_mse_by_batch = torch.tensor([0])
        self.latest_timesteps = torch.tensor([0])
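
    # A hypothetical opt dict for this injector, inferred from the keys read above (the trainer
    # normally assembles this from its YAML config; state-key names are illustrative):
    #
    #   opt = {
    #       'generator': 'generator', 'in': 'hq_audio', 'out': 'mse_loss',
    #       'out_key_vb_loss': 'vb_loss', 'out_key_x_start': 'x_start_pred',
    #       'beta_schedule': {'schedule_name': 'linear', 'num_diffusion_timesteps': 4000},
    #       'diffusion_args': {...},   # forwarded directly to SpacedDiffusion
    #       'sampler_type': 'uniform',
    #       'model_input_keys': {'aligned_conditioning': 'text_tokens'},
    #   }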

    def extra_metrics(self):
        # Reports the mean MSE of the most recent batch within the upper quarter (t > 3T/4)
        # and the mid-upper quarter (T/2 < t < 3T/4) of the timestep range.
        uqt = self.latest_timesteps > self.num_timesteps * 3 / 4
        uql = (self.latest_mse_by_batch * uqt).sum() / uqt.sum() if uqt.sum() != 0 else 0
        muqt = (self.latest_timesteps > self.num_timesteps / 2) * (self.latest_timesteps < self.num_timesteps * 3 / 4)
        muql = (self.latest_mse_by_batch * muqt).sum() / muqt.sum() if muqt.sum() != 0 else 0
        d = {
            'upper_quantile_mse_loss': uql,
            'mid_upper_quantile_mse_loss': muql,
        }
        if hasattr(self, 'schedule_sampler') and isinstance(self.schedule_sampler, LossSecondMomentResampler):
            d['sampler_warmed_up'] = torch.tensor(float(self.schedule_sampler._warmed_up()))
        return d

    def forward(self, state):
        gen = self.env['generators'][self.opt['generator']]
        hq = state[self.input]
        # Inputs must be normalized to [-1, 1].
        assert hq.max() < 1.000001 and hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}"

        with autocast(enabled=self.env['opt']['fp16']):
            if not gen.training or (self.deterministic_timesteps_every != 0 and
                                    self.env['step'] % self.deterministic_timesteps_every == 0):
                sampler = self.deterministic_sampler
            else:
                sampler = self.schedule_sampler
                self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.

            model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
            if self.preprocess_fn is not None:
                hq = getattr(gen.module, self.preprocess_fn)(hq)

            t, weights = sampler.sample(hq.shape[0], hq.device)
            if self.causal_mode:
                # Draw a random causal slope from the configured range for this batch.
                cs, ce = self.causal_slope_range
                slope = random.random() * (ce - cs) + cs
                diffusion_outputs = self.diffusion.causal_training_losses(gen, hq, t, model_kwargs=model_inputs,
                                                                          channel_balancing_fn=self.channel_balancing_fn,
                                                                          causal_slope=slope)
            else:
                diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs,
                                                                   channel_balancing_fn=self.channel_balancing_fn)

            if isinstance(sampler, LossAwareSampler):
                sampler.update_with_local_losses(t, diffusion_outputs['loss'])

            if len(self.extra_model_output_keys) > 0:
                assert len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs'])
                out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
            else:
                out = {}
            out.update({self.output: diffusion_outputs['mse'],
                        self.output_variational_bounds_key: diffusion_outputs['vb'],
                        self.output_x_start_key: diffusion_outputs['x_start_predicted']})
            self.latest_mse_by_batch = diffusion_outputs['mse_by_batch'].detach().clone()
            self.latest_timesteps = t.clone()

        return out
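
    # Downstream losses typically reference the returned state keys; e.g., with out='mse_loss' and
    # out_key_vb_loss='vb_loss', the trainer's loss config would weight and sum state['mse_loss']
    # and state['vb_loss'] (key names here are hypothetical).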


def closest_multiple(inp, multiple):
    div = inp // multiple
    mod = inp % multiple
    if mod == 0:
        return inp
    else:
        return int((div + 1) * multiple)
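
# For example, closest_multiple(22050, 4096) == 24576 (22050 is a hypothetical sample count):
# inference below pads output lengths up to a multiple the underlying model can consume.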


# Performs inference using a network trained to predict a reverse diffusion process, which yields
# an image (or an audio signal).
class GaussianDiffusionInferenceInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        use_ddim = opt_get(opt, ['use_ddim'], False)
        self.generator = opt['generator']
        self.output_batch_size = opt['output_batch_size']
        self.output_scale_factor = opt['output_scale_factor']
        self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False)  # When specified, shifts the output of this injector from [-1,1] to [0,1].
        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
        if use_ddim:
            spacing = "ddim" + str(opt['respaced_timestep_spacing'])
        else:
            spacing = [opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])]
        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], spacing)
        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
        self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
        self.model_input_keys = opt_get(opt, ['model_input_keys'], {})  # Maps model kwarg names to state keys.
        self.use_ema_model = opt_get(opt, ['use_ema'], False)
        self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
        self.multiple_requirement = opt_get(opt, ['multiple_requirement'], 4096)
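
    # A hypothetical opt dict for this injector (keys inferred from the constructor above;
    # state-key names are illustrative):
    #
    #   opt = {
    #       'generator': 'generator', 'out': 'sampled_audio',
    #       'output_batch_size': 4, 'output_scale_factor': 1,
    #       'beta_schedule': {'schedule_name': 'linear', 'num_diffusion_timesteps': 4000},
    #       'diffusion_args': {...},
    #       'use_ddim': True, 'respaced_timestep_spacing': 50,
    #       'model_input_keys': {'spectrogram': 'mel'},
    #   }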

    def forward(self, state):
        if self.use_ema_model:
            gen = self.env['emas'][self.opt['generator']]
        else:
            gen = self.env['generators'][self.opt['generator']]
        model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
        gen.eval()
        with torch.no_grad():
            # Infer the output shape (and device) from whichever conditioning input is present.
            if 'low_res' in model_inputs.keys():
                output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
                                model_inputs['low_res'].shape[-1] * self.output_scale_factor)
                dev = model_inputs['low_res'].device
            elif 'spectrogram' in model_inputs.keys():
                output_shape = (self.output_batch_size, 1, closest_multiple(model_inputs['spectrogram'].shape[-1] * self.output_scale_factor, self.multiple_requirement))
                dev = model_inputs['spectrogram'].device
            elif 'discrete_spectrogram' in model_inputs.keys():
                output_shape = (self.output_batch_size, 1, closest_multiple(model_inputs['discrete_spectrogram'].shape[-1] * 1024, self.multiple_requirement))
                dev = model_inputs['discrete_spectrogram'].device
            else:
                raise NotImplementedError

            noise = None
            if self.noise_style == 'zero':
                noise = torch.zeros(output_shape, device=dev)
            elif self.noise_style == 'fixed':
                # Reuse the same starting noise across calls so outputs are comparable over training.
                if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
                    self.fixed_noise = torch.randn(output_shape, device=dev)
                noise = self.fixed_noise

            gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
            if self.undo_n1_to_1:
                gen = (gen + 1) / 2
            return {self.output: gen}