From 5f575b5d3c8d73539b18cd1db4e0f88eb6244b33 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 7 Jul 2022 15:17:47 -0600 Subject: [PATCH] Add causal timestep adjustments --- codes/models/diffusion/gaussian_diffusion.py | 65 ++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 9c17e951..220a79ef 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -17,6 +17,63 @@ from .nn import mean_flat from .losses import normal_kl, discretized_gaussian_log_likelihood +def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True): + """ + Remaps [t] from a batch of integers into a causal sequence [S] long where each sequence element is [causal_slope] + timesteps advanced from the previous sequence element. At t=0, the sequence is all 0s and at t=[num_timesteps], the + sequence is all [num_timesteps]. + + As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must + be considered at inference time. + :param t: Batched timestep integers + :param S: Sequence length. + :param num_timesteps: Number of total timesteps. + :param causal_slope: The causal slope. Ex: "2" means each sequence element will be 2 timesteps ahead of its predecessor. + :param add_jitter: Whether or not to add random jitter into the extra gaps between timesteps added by this function. + Should be true for training and false for inference. + :return: [b,S] sequence of timestep integers. + """ + S_sloped = causal_slope * (S-1) + # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this + # actually work, we map the existing t from the timescale specified to the model to the causal timescale: + adj_t = t * (num_timesteps + S_sloped) // num_timesteps + if add_jitter: + jitter = (random.random() - .5) * S_sloped + adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped) + + # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope. + t = adj_t.unsqueeze(1).repeat(1, S) + t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long() + return t + + +def graph_causal_timestep_adjustment(): + S = 400 + slope=4 + num_timesteps=4000 + #for num_timesteps in range(100, 4000, 200): + t_res = [] + for t in range(num_timesteps, -1, -num_timesteps//50): + T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] + t_res.append(T) + plt.plot(T.numpy()) + plt.ylim(0,4000) + plt.xlim(0,500) + plt.savefig(f'{t}.png') + plt.clf() + + for i in range(len(t_res)): + for j in range(len(t_res)): + if i == j: + continue + #assert not torch.all(t_res[i] == t_res[j]) + plt.ylim(0,4000) + plt.xlim(0,500) + plt.ylabel('timestep') + plt.savefig(f'{num_timesteps}.png') + plt.clf() + + def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): """ Get a pre-defined beta schedule for the given name. @@ -790,6 +847,14 @@ class GaussianDiffusion: output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} + def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None): + """ + Compute training losses for a causal diffusion process. + """ + assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis." + t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True) + return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn) + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None): """ Compute training losses for a single timestep.