From 8657d4d0604d140264194ad9e071453b0a8277d9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 9 Jul 2022 09:43:54 -0600 Subject: [PATCH] fix causal diffusion masking for low timesteps --- codes/models/diffusion/gaussian_diffusion.py | 32 ++++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 6ad457c0..610591fd 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -46,10 +46,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope. t = adj_t.unsqueeze(1).repeat(1, S) - t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long() + t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(-1, num_timesteps).long() return t +def causal_mask_and_fix(t, num_timesteps): + mask1 = t == num_timesteps + t[mask1] = num_timesteps-1 + mask2 = t == -1 + t[mask2] = 0 + return t, mask1.logical_or(mask2) + + def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): """ Get a pre-defined beta schedule for the given name. 
@@ -579,8 +587,7 @@ class GaussianDiffusion: mask = torch.zeros_like(img) if causal: t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1) - mask = t == self.num_timesteps - t[mask] = self.num_timesteps-1 + t, mask = causal_mask_and_fix(t, self.num_timesteps) mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): out = self.p_sample( @@ -809,7 +816,6 @@ class GaussianDiffusion: mask = torch.zeros_like(img) if causal: t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1) - mask = t == self.num_timesteps - t[mask] = self.num_timesteps-1 + t, mask = causal_mask_and_fix(t, self.num_timesteps) mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): @@ -894,8 +900,8 @@ class GaussianDiffusion: noise = th.randn_like(x_start) if len(t.shape) == 3: - t_mask = t != self.num_timesteps - t[t_mask.logical_not()] = self.num_timesteps-1 + t, t_mask = causal_mask_and_fix(t, self.num_timesteps) + t_mask = t_mask.logical_not() # This is used to mask out losses for timesteps that are out of bounds. else: t_mask = torch.ones_like(x_start) @@ -939,7 +945,7 @@ class GaussianDiffusion: x_t=x_t, t=t, clip_denoised=False, - )["output"] * t_mask + )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. 
@@ -964,6 +971,7 @@ class GaussianDiffusion: s_err = channel_balancing_fn(s_err) terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1) terms["mse"] = mean_flat(s_err) terms["x_start_predicted"] = x_start_pred if "vb" in terms: + terms["vb"] = terms["vb"] * t_mask if channel_balancing_fn is not None: @@ -1085,14 +1093,20 @@ def test_causal_training_losses(): def graph_causal_timestep_adjustment(): import matplotlib.pyplot as plt - S = 400 + S = 2000 #slope=4 num_timesteps=4000 - for slpe in range(10, 400, 10): + for slpe in range(0, 200, 10): slope = slpe / 10 t_res = [] for t in range(num_timesteps, -1, -num_timesteps//50): T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] + + # The following adjustment makes it easier to visualize the timestep regions where the model is actually working. + T_adj = (T == num_timesteps).logical_or(T == -1) + T[T_adj] = t + print(t, T.float().mean()) + t_res.append(T) plt.plot(T.numpy())