fix causal diffusion masking for low timesteps

James Betker 2022-07-09 09:43:54 -06:00
parent 79a5b54e57
commit 8657d4d060


@@ -46,10 +46,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
     # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
     t = adj_t.unsqueeze(1).repeat(1, S)
-    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
+    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(-1, num_timesteps).long()
     return t
+
+
+def causal_mask_and_fix(t, num_timesteps):
+    mask1 = t == num_timesteps
+    t[mask1] = num_timesteps-1
+    mask2 = t == -1
+    t[mask2] = 0
+    return t, mask1.logical_or(mask2)
+
+
 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
     """
     Get a pre-defined beta schedule for the given name.
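
For intuition, a minimal standalone sketch of what the new helper does (the function body is copied from the hunk above; the toy tensor and the small num_timesteps are illustrative). Timesteps that the causal ramp pushed past either end of the valid range are clamped back in, and a combined mask records which positions were out of bounds, now including the t == -1 case at low timesteps that this commit fixes:

import torch

def causal_mask_and_fix(t, num_timesteps):
    mask1 = t == num_timesteps
    t[mask1] = num_timesteps-1
    mask2 = t == -1
    t[mask2] = 0
    return t, mask1.logical_or(mask2)

t = torch.tensor([-1, 0, 2, 4])              # illustrative; num_timesteps = 4
fixed, oob = causal_mask_and_fix(t.clone(), 4)
print(fixed)  # tensor([0, 0, 2, 3])
print(oob)    # tensor([ True, False, False,  True])
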
@@ -579,8 +587,7 @@ class GaussianDiffusion:
         mask = torch.zeros_like(img)
         if causal:
             t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
-            mask = t == self.num_timesteps
-            t[mask] = self.num_timesteps-1
+            t, mask = causal_mask_and_fix(t, self.num_timesteps)
             mask = mask.repeat(img.shape[0], img.shape[1], 1)
         with th.no_grad():
             out = self.p_sample(
@@ -809,7 +816,7 @@ class GaussianDiffusion:
         mask = torch.zeros_like(img)
         if causal:
             t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
-            mask = t == self.num_timesteps
+            t, mask = causal_mask_and_fix(t, self.num_timesteps)
             t[mask] = self.num_timesteps-1
             mask = mask.repeat(img.shape[0], img.shape[1], 1)
         with th.no_grad():
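
Both sampling hunks above follow the same pattern. A hedged sketch of how the combined mask gets its shape, assuming t arrives as a single timestep vector shared across the batch, i.e. shape (1, 1, S) after the unsqueeze (an assumption; only the repeat call is visible in the diff):

import torch

B, C, S, num_timesteps = 2, 8, 16, 4000
img = torch.randn(B, C, S)
t = torch.randint(-1, num_timesteps + 1, (1, 1, S))  # assumed shape after unsqueeze(1)

# Equivalent to the mask returned by causal_mask_and_fix:
mask = (t == num_timesteps).logical_or(t == -1)      # (1, 1, S)
mask = mask.repeat(img.shape[0], img.shape[1], 1)    # -> (B, C, S), as in the diff
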
@@ -894,8 +901,8 @@ class GaussianDiffusion:
         noise = th.randn_like(x_start)
         if len(t.shape) == 3:
-            t_mask = t != self.num_timesteps
-            t[t_mask.logical_not()] = self.num_timesteps-1
+            t, t_mask = causal_mask_and_fix(t, self.num_timesteps)
+            t_mask = t_mask.logical_not()  # This is used to mask out losses for timesteps that are out of bounds.
         else:
             t_mask = torch.ones_like(x_start)
@@ -939,7 +946,7 @@ class GaussianDiffusion:
                 x_t=x_t,
                 t=t,
                 clip_denoised=False,
-            )["output"] * t_mask
+            )["output"]
             if self.loss_type == LossType.RESCALED_MSE:
                 # Divide by 1000 for equivalence with initial implementation.
                 # Without a factor of 1/1000, the VB term hurts the MSE term.
@@ -964,6 +971,7 @@ class GaussianDiffusion:
                 s_err = channel_balancing_fn(s_err)
             terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
             terms["mse"] = mean_flat(s_err)
+            terms["vb"] = terms["vb"] * t_mask
             terms["x_start_predicted"] = x_start_pred
             if "vb" in terms:
                 if channel_balancing_fn is not None:
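
The three training-loss hunks work together: t_mask is inverted so it is True only for in-bounds positions, the mask is no longer applied to the raw _vb_terms_bpd output, and instead terms["vb"] is zeroed where the causal ramp put a position out of bounds. A condensed sketch of the resulting semantics (tensor shapes are assumptions; the bool-by-float multiply relies on PyTorch type promotion):

import torch

num_timesteps = 4000
B, S = 2, 16
t = torch.randint(-1, num_timesteps + 1, (B, 1, S))

oob = (t == num_timesteps).logical_or(t == -1)  # out-of-bounds positions
t_mask = oob.logical_not()                      # keep losses only in bounds

vb = torch.rand(B, 1, S)   # stand-in for the per-position VB term
vb = vb * t_mask           # out-of-bounds positions contribute zero loss
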
@@ -1085,14 +1093,20 @@ def test_causal_training_losses():
 def graph_causal_timestep_adjustment():
     import matplotlib.pyplot as plt
-    S = 400
+    S = 2000
     #slope=4
     num_timesteps=4000
-    for slpe in range(10, 400, 10):
+    for slpe in range(0, 200, 10):
         slope = slpe / 10
         t_res = []
         for t in range(num_timesteps, -1, -num_timesteps//50):
             T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
+            # The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
+            T_adj = (T == num_timesteps).logical_or(T == -1)
+            T[T_adj] = t
+            print(t, T.float().mean())
             t_res.append(T)
             plt.plot(T.numpy())
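
To see why out-of-range values arise in the first place, a hypothetical condensed form of the ramp inside causal_timestep_adjustment (the real function also re-maps t before applying the slope and optionally adds jitter; this keeps only the clamp that the commit widened to -1, and assumes the re-mapped t can go negative at low timesteps, which is what makes the -1 bound necessary):

import torch

def causal_ramp(t, S, num_timesteps, causal_slope):
    # Each sequence position i gets timestep t + i * slope, clamped to
    # [-1, num_timesteps] so overruns at both ends stay detectable.
    base = torch.full((S,), float(t))
    return (base + torch.arange(S) * causal_slope).clamp(-1, num_timesteps).long()

print(causal_ramp(t=0, S=8, num_timesteps=10, causal_slope=2))
# tensor([ 0,  2,  4,  6,  8, 10, 10, 10])   -> the == num_timesteps tail
print(causal_ramp(t=-5, S=8, num_timesteps=10, causal_slope=2))
# tensor([-1, -1, -1,  1,  3,  5,  7,  9])   -> the == -1 case fixed here
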