fix causal diffusion masking for low timesteps

This commit is contained in:
James Betker 2022-07-09 09:43:54 -06:00
parent 79a5b54e57
commit 8657d4d060

View File

@ -46,10 +46,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
# Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
t = adj_t.unsqueeze(1).repeat(1, S)
t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(-1, num_timesteps).long()
return t
def causal_mask_and_fix(t, num_timesteps):
mask1 = t == num_timesteps
t[mask1] = num_timesteps-1
mask2 = t == -1
t[mask2] = 0
return t, mask1.logical_or(mask2)
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
@ -579,8 +587,7 @@ class GaussianDiffusion:
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
mask = t == self.num_timesteps
t[mask] = self.num_timesteps-1
t, mask = causal_mask_and_fix(t, self.num_timesteps)
mask = mask.repeat(img.shape[0], img.shape[1], 1)
with th.no_grad():
out = self.p_sample(
@ -809,7 +816,7 @@ class GaussianDiffusion:
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
mask = t == self.num_timesteps
t, mask = causal_mask_and_fix(t, self.num_timesteps)
t[mask] = self.num_timesteps-1
mask = mask.repeat(img.shape[0], img.shape[1], 1)
with th.no_grad():
@ -894,8 +901,8 @@ class GaussianDiffusion:
noise = th.randn_like(x_start)
if len(t.shape) == 3:
t_mask = t != self.num_timesteps
t[t_mask.logical_not()] = self.num_timesteps-1
t, t_mask = causal_mask_and_fix(t, self.num_timesteps)
t_mask = t_mask.logical_not() # This is used to mask out losses for timesteps that are out of bounds.
else:
t_mask = torch.ones_like(x_start)
@ -939,7 +946,7 @@ class GaussianDiffusion:
x_t=x_t,
t=t,
clip_denoised=False,
)["output"] * t_mask
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
@ -964,6 +971,7 @@ class GaussianDiffusion:
s_err = channel_balancing_fn(s_err)
terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
terms["mse"] = mean_flat(s_err)
terms["vb"] = terms["vb"] * t_mask
terms["x_start_predicted"] = x_start_pred
if "vb" in terms:
if channel_balancing_fn is not None:
@ -1085,14 +1093,20 @@ def test_causal_training_losses():
def graph_causal_timestep_adjustment():
import matplotlib.pyplot as plt
S = 400
S = 2000
#slope=4
num_timesteps=4000
for slpe in range(10, 400, 10):
for slpe in range(0, 200, 10):
slope = slpe / 10
t_res = []
for t in range(num_timesteps, -1, -num_timesteps//50):
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
# The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
T_adj = (T == num_timesteps).logical_or(T == -1)
T[T_adj] = t
print(t, T.float().mean())
t_res.append(T)
plt.plot(T.numpy())