From 82bd62019f603b8411e17bd8effba9bbbd93775e Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Jul 2022 11:54:09 -0600 Subject: [PATCH] diffuse the cascaded prior for continuous sr model --- .../audio/music/transformer_diffusion13.py | 73 +++++++++++++------ .../injectors/gaussian_diffusion_injector.py | 6 +- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 05ff8fe4..fe47cef9 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -1,10 +1,10 @@ import itertools +import random from random import randrange import torch import torch.nn as nn import torch.nn.functional as F -import torchvision.utils from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear @@ -88,7 +88,7 @@ class ConditioningEncoder(nn.Module): def forward(self, x, resolution): h = self.init(x) + self.resolution_embedding(resolution).unsqueeze(-1) h = self.attn(h) - return h[:, :, :6] + return h[:, :, :5] class TransformerDiffusion(nn.Module): @@ -130,10 +130,14 @@ class TransformerDiffusion(nn.Module): nn.SiLU(), linear(time_embed_dim, model_channels), ) + self.prior_time_embed = nn.Sequential( + linear(time_embed_dim, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, model_channels), + ) self.resolution_embed = nn.Embedding(resolution_steps, model_channels) self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,6)) - self.unconditioned_prior = nn.Parameter(torch.zeros(1,in_channels,1)) + self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,5)) self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1) self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)]) @@ -169,7 +173,7 @@ class TransformerDiffusion(nn.Module): } return groups - def input_to_random_resolution_and_window(self, x): + def input_to_random_resolution_and_window(self, x, ts, diffuser): """ This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well as caches an internal prior for the rescoped target which will be used in training. @@ -185,26 +189,47 @@ class TransformerDiffusion(nn.Module): s = s[:,:,start:start+self.max_window] s_prior = F.interpolate(s, scale_factor=.25, mode='nearest') s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True) - self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)) + + # Now diffuse the prior randomly between the x timestep and 0. + adv = torch.rand_like(ts.float()) + t_prior = (adv * ts).long() + s_prior_diffused = diffuser.q_sample(s_prior, t_prior, torch.randn_like(s_prior)) + + self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)) return s - def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False): + def forward(self, x, timesteps, prior_timesteps=None, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False): + """ + Predicts the previous diffusion timestep of x, given a partially diffused low-resolution prior and a conditioning + input. + + All parameters are optional because during training, input_to_random_resolution_and_window is used by a training + harness to preformat the inputs and fill in the parameters as state variables. + + Args: + x: Prediction prior. + timesteps: Number of timesteps x has been diffused for. + prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element. + x_prior: A low-resolution prior that guides the model. + resolution: Integer indicating the operating resolution level. '0' is the highest resolution. + conditioning_input: A semi-related (un-aligned) conditioning input which is used to guide diffusion. Similar to a class input, but hooked to a learned conditioning encoder. + conditioning_free: Whether or not to ignore the conditioning input. + """ conditioning_input = x_prior if conditioning_input is None else conditioning_input - h = x if resolution is None: # This is assumed to be training. assert self.preprocessed is not None, 'Preprocessing function not called.' assert x_prior is None, 'Provided prior will not be used, instead preprocessing output will be used.' - h_sub, resolution = self.preprocessed + x_prior, prior_timesteps, resolution = self.preprocessed self.preprocessed = None else: - assert h.shape[-1] > x_prior.shape[-1] * 3.9, f'{h.shape} {x_prior.shape}' - h_sub = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True) + assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}' + x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True) + assert torch.all(timesteps - prior_timesteps > 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}' if conditioning_free: - h_sub = self.unconditioned_prior.repeat(x.shape[0], 1, x.shape[-1]) - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) else: MIN_COND_LEN = 200 MAX_COND_LEN = 1200 @@ -217,17 +242,17 @@ class TransformerDiffusion(nn.Module): # Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((h.shape[0], 1, 1), - device=h.device) < self.unconditioned_percentage - h_sub = torch.where(unconditioned_batches, self.unconditioned_prior.repeat(h_sub.shape[0], 1, h_sub.shape[-1]), h_sub) + unconditioned_batches = torch.rand((x.shape[0], 1, 1), + device=x.device) < self.unconditioned_percentage code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb) with torch.autocast(x.device.type, enabled=self.enable_fp16): time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) + prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim)) res_emb = self.resolution_embed(resolution) - blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1) + blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1) - h = torch.cat([h, h_sub], dim=1) + h = torch.cat([x, x_prior], dim=1) h = self.inp_block(h) for layer in self.layers: h = checkpoint(layer, h, blk_emb) @@ -236,7 +261,7 @@ class TransformerDiffusion(nn.Module): out = self.out(h) # Defensively involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - unused_params = [self.unconditioned_prior, self.unconditioned_embedding] + unused_params = [self.unconditioned_embedding] extraneous_addition = 0 for p in unused_params: extraneous_addition = extraneous_addition + p.mean() @@ -251,6 +276,12 @@ def register_transformer_diffusion13(opt_net, opt): def test_tfd(): + from models.diffusion.respace import SpacedDiffusion + from models.diffusion.respace import space_timesteps + from models.diffusion.gaussian_diffusion import get_named_beta_schedule + diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', + betas=get_named_beta_schedule('linear', 4000)) clip = torch.randn(2,256,10336) cond = torch.randn(2,256,10336) ts = torch.LongTensor([600, 600]) @@ -258,8 +289,8 @@ def test_tfd(): num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1, unconditioned_percentage=.6) for k in range(100): - x = model.input_to_random_resolution_and_window(clip, x_prior=clip) - model(x, ts, clip) + x = model.input_to_random_resolution_and_window(clip, ts, diffuser) + model(x, ts, conditioning_input=cond) def remove_conditioning(sd_path): diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index ca364d90..e2a8865f 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -89,10 +89,10 @@ class GaussianDiffusionInjector(Injector): sampler = self.schedule_sampler self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically. model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()} - if self.preprocess_fn is not None: - hq = getattr(gen.module, self.preprocess_fn)(hq) - t, weights = sampler.sample(hq.shape[0], hq.device) + + if self.preprocess_fn is not None: + hq = getattr(gen.module, self.preprocess_fn)(hq, t, self.diffusion) if self.causal_mode: cs, ce = self.causal_slope_range slope = random.random() * (ce-cs) + cs