diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index 38f2a9a4..07862598 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -81,11 +81,14 @@ class ConditioningEncoder(nn.Module):
                  attn_blocks=6,
                  num_attn_heads=8,
                  dropout=.1,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 time_proj=True):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
-        self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+        self.time_proj = time_proj
+        if time_proj:
+            self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
         self.attn = Encoder(
                 dim=embedding_dim,
                 depth=attn_blocks,
@@ -103,8 +106,9 @@ class ConditioningEncoder(nn.Module):
 
     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
-        time_enc = self.time_proj(time_emb)
-        h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
+        if self.time_proj:
+            time_enc = self.time_proj(time_emb)
+            h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
         h = self.attn(h).permute(0,2,1)
         return h
 
@@ -125,6 +129,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  input_cond_dim=1024,
                  num_heads=8,
                  dropout=0,
+                 time_proj=False,
                  use_fp16=False,
                  checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
                  # Parameters for regularization.
@@ -141,7 +146,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)
         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -210,11 +215,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         cond_left = cond_aligned[:,:,:break_pt]
         cond_right = cond_aligned[:,:,break_pt:]
 
-        # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
-        to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
-        cond_left = cond_left[:,:,:-to_remove_left]
-        to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
-        cond_right = cond_right[:,:,to_remove_right:]
+        if self.training:
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:-to_remove_left]
+            to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove_right:]
 
         # Concatenate the _pre and _post back on.
         cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
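
Note on the last hunk above: the random truncation of the aligned conditioning is now gated on `self.training`, so `model.eval()` leaves the aligned segment intact at inference time. A minimal standalone sketch of the train-only behavior (shapes are illustrative, and MIN_MARGIN is assumed to be a small constant defined elsewhere in tfdpc_v5.py; both segments must be longer than it):

import random
import torch

MIN_MARGIN = 5  # assumed value, for illustration only

def truncate_aligned(cond_left, cond_right, training):
    # Train-only: drop a random span from each segment's inner edge so the
    # network must learn to reconstruct the missing aligned conditioning.
    if training:
        to_remove_left = random.randint(1, cond_left.shape[-1] - MIN_MARGIN)
        cond_left = cond_left[:, :, :-to_remove_left]
        to_remove_right = random.randint(1, cond_right.shape[-1] - MIN_MARGIN)
        cond_right = cond_right[:, :, to_remove_right:]
    return cond_left, cond_right

l, r = truncate_aligned(torch.randn(2, 256, 100), torch.randn(2, 256, 100), training=True)
# With training=False the segments pass through untouched.
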
diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 220a79ef..cd9dc208 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -7,14 +7,15 @@ Docstrings have been added, as well as DDIM sampling and a new collection of bet
 
 import enum
 import math
+import random
 
 import numpy as np
 import torch
 import torch as th
 from tqdm import tqdm
 
-from .nn import mean_flat
-from .losses import normal_kl, discretized_gaussian_log_likelihood
+from models.diffusion.nn import mean_flat
+from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood
 
 
 def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
@@ -37,43 +38,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
     # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
     # actually work, we map the existing t from the timescale specified to the model to the causal timescale:
     adj_t = t * (num_timesteps + S_sloped) // num_timesteps
+    adj_t = adj_t - S_sloped
     if add_jitter:
-        jitter = (random.random() - .5) * S_sloped
-        adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
+        t_gap = (num_timesteps + S_sloped) / num_timesteps
+        jitter = (2*random.random()-1) * t_gap
+        adj_t = (adj_t+jitter).clamp(-S_sloped, num_timesteps)
 
     # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
     t = adj_t.unsqueeze(1).repeat(1, S)
-    t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
+    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
     return t
 
 
-def graph_causal_timestep_adjustment():
-    S = 400
-    slope=4
-    num_timesteps=4000
-    #for num_timesteps in range(100, 4000, 200):
-    t_res = []
-    for t in range(num_timesteps, -1, -num_timesteps//50):
-        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
-        t_res.append(T)
-        plt.plot(T.numpy())
-    plt.ylim(0,4000)
-    plt.xlim(0,500)
-    plt.savefig(f'{t}.png')
-    plt.clf()
-
-    for i in range(len(t_res)):
-        for j in range(len(t_res)):
-            if i == j:
-                continue
-            #assert not torch.all(t_res[i] == t_res[j])
-    plt.ylim(0,4000)
-    plt.xlim(0,500)
-    plt.ylabel('timestep')
-    plt.savefig(f'{num_timesteps}.png')
-    plt.clf()
-
-
 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
     """
     Get a pre-defined beta schedule for the given name.
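
To make the reworked mapping above concrete: each scalar batch timestep is stretched onto an extended scale, shifted down by S_sloped, and then a per-position ramp is added back, so earlier sequence positions sit at lower (cleaner) timesteps and denoise first. A standalone sketch, assuming S_sloped = S * causal_slope (its actual definition sits above the visible hunk):

import torch

def causal_t(t, S, num_timesteps, causal_slope):
    S_sloped = S * causal_slope  # assumed definition: extra timesteps added by the ramp
    adj_t = t * (num_timesteps + S_sloped) // num_timesteps - S_sloped
    ramp = torch.arange(0, S) * causal_slope  # position-dependent offset
    return (adj_t.unsqueeze(1) + ramp).clamp(0, num_timesteps).long()

print(causal_t(torch.tensor([2000]), S=8, num_timesteps=4000, causal_slope=100))
# tensor([[1600, 1700, 1800, 1900, 2000, 2100, 2200, 2300]]) -- position 0 is cleanest.
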
@@ -319,7 +295,7 @@ class GaussianDiffusion:
             model_kwargs = {}
 
         B, C = x.shape[:2]
-        assert t.shape == (B,)
+        assert t.shape == (B,) or t.shape == (B,1,x.shape[-1])
         model_output = model(x, self._scale_timesteps(t), **model_kwargs)
         if self.conditioning_free:
             model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
@@ -844,7 +820,10 @@ class GaussianDiffusion:
 
         # At the first timestep return the decoder NLL,
         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
-        output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        if len(t.shape) == 1:
+            output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        else:
+            output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}
 
     def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
@@ -852,8 +831,9 @@ class GaussianDiffusion:
         Compute training losses for a causal diffusion process.
         """
         assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
-        t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
-        return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
+        ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
+        ct = ct.unsqueeze(1)  # Necessary to make the output shape compatible with x_start.
+        return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
 
     def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
         """
@@ -872,6 +852,13 @@ class GaussianDiffusion:
             model_kwargs = {}
         if noise is None:
            noise = th.randn_like(x_start)
+
+        if len(t.shape) == 3:
+            t_mask = t != self.num_timesteps
+            t[t_mask.logical_not()] = self.num_timesteps-1
+        else:
+            t_mask = torch.ones_like(x_start)
+
         x_t = self.q_sample(x_start, t, noise=noise)
 
         terms = {}
@@ -912,7 +899,7 @@ class GaussianDiffusion:
                     x_t=x_t,
                     t=t,
                     clip_denoised=False,
-                )["output"]
+                )["output"] * t_mask
                 if self.loss_type == LossType.RESCALED_MSE:
                     # Divide by 1000 for equivalence with initial implementation.
                     # Without a factor of 1/1000, the VB term hurts the MSE term.
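
The `t_mask` machinery above exists because the causal timestep grid can clamp at `num_timesteps`, which is both out of range for the beta-schedule lookups and pure noise that carries no training signal. A shape-level sketch of the masking, with illustrative dimensions:

import torch

num_timesteps = 4000
# Causal mode passes a (B, 1, N) per-position timestep grid instead of a (B,) vector.
t = torch.randint(0, num_timesteps + 1, (4, 1, 400))

t_mask = t != num_timesteps      # False where the position is pure noise
t = t.clone()
t[~t_mask] = num_timesteps - 1   # clamp back into the valid index range for q_sample
# The VB and MSE terms are later multiplied by t_mask, which broadcasts over the
# channel dimension and zeroes the loss at masked positions.
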
@@ -932,7 +919,7 @@ class GaussianDiffusion:
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
-            s_err = (target - model_output) ** 2
+            s_err = t_mask * (target - model_output) ** 2
             if channel_balancing_fn is not None:
                 s_err = channel_balancing_fn(s_err)
             terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
@@ -1039,3 +1026,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
+
+
+def test_causal_training_losses():
+    from models.diffusion.respace import SpacedDiffusion
+    from models.diffusion.respace import space_timesteps
+    diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
+                           model_var_type='learned_range', loss_type='mse',
+                           betas=get_named_beta_schedule('linear', 4000),
+                           conditioning_free=False, conditioning_free_k=1)
+
+    class IdentityTwoArg(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x, *args, **kwargs):
+            return x.repeat(1,2,1)
+
+    model = IdentityTwoArg()
+    diff.causal_training_losses(model, torch.randn(4,256,400), torch.tensor([500,1000,3000,3500]), causal_slope=4)
+
+
+def graph_causal_timestep_adjustment():
+    import matplotlib.pyplot as plt
+    S = 400
+    #slope=4
+    num_timesteps=4000
+    for slpe in range(10, 400, 10):
+        slope = slpe / 10
+        t_res = []
+        for t in range(num_timesteps, -1, -num_timesteps//50):
+            T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
+            t_res.append(T)
+            plt.plot(T.numpy())
+
+        for i in range(len(t_res)):
+            for j in range(len(t_res)):
+                if i == j:
+                    continue
+                #assert not torch.all(t_res[i] == t_res[j])
+        plt.ylim(0,num_timesteps)
+        plt.xlim(0,4000)
+        plt.ylabel('timestep')
+        plt.savefig(f'{slpe}.png')
+        plt.clf()
+
+
+if __name__ == '__main__':
+    #test_causal_training_losses()
+    graph_causal_timestep_adjustment()
\ No newline at end of file
diff --git a/codes/models/diffusion/nn.py b/codes/models/diffusion/nn.py
index 169bc0d9..50203d75 100644
--- a/codes/models/diffusion/nn.py
+++ b/codes/models/diffusion/nn.py
@@ -122,7 +122,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     freqs = th.exp(
         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
     ).to(device=timesteps.device)
-    args = timesteps[:, None].float() * freqs[None]
+    if len(timesteps.shape) == 1:
+        args = timesteps[:, None].float() * freqs[None]
+    else:
+        args = (timesteps.float() * freqs.view(1,half,1)).permute(0,2,1)
     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
     if dim % 2:
         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index 02ef7d59..895f60fe 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -365,8 +365,11 @@ class RMSScaleShiftNorm(nn.Module):
         norm = x / norm.clamp(min=self.eps) * self.g
 
         ss_emb = self.scale_shift_process(norm_scale_shift_inp)
-        scale, shift = torch.chunk(ss_emb, 2, dim=1)
-        h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        scale, shift = torch.chunk(ss_emb, 2, dim=-1)
+        if len(scale.shape) == 2:
+            scale = scale.unsqueeze(1)
+            shift = shift.unsqueeze(1)
+        h = norm * (1 + scale) + shift
         return h
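
The nn.py branch above is what lets per-position causal timesteps flow through the model: a (B, 1, N) grid now yields a (B, N, dim) embedding, and the matching x_transformers.py change applies scale/shift per sequence position rather than once per batch item. A quick shape check, restating the patched branch with illustrative dimensions:

import math
import torch as th

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = th.exp(-math.log(max_period) * th.arange(0, half, dtype=th.float32) / half)
    if len(timesteps.shape) == 1:
        args = timesteps[:, None].float() * freqs[None]                       # (B, half)
    else:
        args = (timesteps.float() * freqs.view(1, half, 1)).permute(0, 2, 1)  # (B, N, half)
    return th.cat([th.cos(args), th.sin(args)], dim=-1)

print(timestep_embedding(th.randint(0, 4000, (2, 1, 400)), dim=64).shape)
# torch.Size([2, 400, 64]) -- one embedding per sequence position
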
diff --git a/codes/train.py b/codes/train.py
index 891cf066..42027511 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -339,7 +339,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_cheater_gen.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 692bbf46..8217043e 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -1,4 +1,5 @@
 import functools
+import random
 
 import torch
 from torch.cuda.amp import autocast
@@ -44,6 +45,8 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        self.causal_mode = opt_get(opt, ['causal_mode'], False)
+        self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])
 
         k = 0
         if 'channel_balancer_proportion' in opt.keys():
@@ -86,7 +89,16 @@ class GaussianDiffusionInjector(Injector):
                 self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
             model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
             t, weights = sampler.sample(hq.shape[0], hq.device)
-            diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
+            if self.causal_mode:
+                cs, ce = self.causal_slope_range
+                slope = random.random() * (ce-cs) + cs
+                diffusion_outputs = self.diffusion.causal_training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                                          channel_balancing_fn=self.channel_balancing_fn,
+                                                                          causal_slope=slope)
+            else:
+                diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                                   channel_balancing_fn=self.channel_balancing_fn)
+
             if isinstance(sampler, LossAwareSampler):
                 sampler.update_with_local_losses(t, diffusion_outputs['loss'])
             if len(self.extra_model_output_keys) > 0:
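
With the injector wiring above, causal training is opt-in via two new option keys, and each step draws a fresh slope uniformly from the configured range. A sketch of the intended usage (the surrounding injector keys are illustrative; only causal_mode and causal_slope_range come from this patch):

import random

injector_opt = {
    'type': 'gaussian_diffusion',   # illustrative surrounding key
    'causal_mode': True,            # new key; defaults to False
    'causal_slope_range': [1, 8],   # new key; slope is re-sampled every step
}

cs, ce = injector_opt['causal_slope_range']
slope = random.random() * (ce - cs) + cs  # uniform draw in [cs, ce), matching the patch
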