forked from mrq/DL-Art-School
fix causal diffusion masking for low timesteps
parent 79a5b54e57
commit 8657d4d060
@@ -46,10 +46,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
     # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
     t = adj_t.unsqueeze(1).repeat(1, S)
-    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
+    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(-1, num_timesteps).long()
     return t
 
 
+def causal_mask_and_fix(t, num_timesteps):
+    mask1 = t == num_timesteps
+    t[mask1] = num_timesteps-1
+    mask2 = t == -1
+    t[mask2] = 0
+    return t, mask1.logical_or(mask2)
+
+
 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
     """
     Get a pre-defined beta schedule for the given name.
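The two changes above work as a pair: `clamp(-1, ...)` now lets under-range positions survive as `-1` instead of being silently folded into timestep 0, and the new `causal_mask_and_fix` helper snaps both kinds of out-of-range values back onto valid timesteps while reporting where it did so. A minimal standalone sketch of the helper's behavior (the toy tensor and `num_timesteps=10` are invented for illustration):

```python
import torch

def causal_mask_and_fix(t, num_timesteps):
    # Positions that overflowed the schedule (t == num_timesteps)...
    mask1 = t == num_timesteps
    t[mask1] = num_timesteps - 1
    # ...and positions that underflowed it (t == -1), the low-timestep
    # case this commit starts handling.
    mask2 = t == -1
    t[mask2] = 0
    return t, mask1.logical_or(mask2)

t = torch.tensor([[-1, 0, 5, 9, 10]])
t, oob = causal_mask_and_fix(t, num_timesteps=10)
print(t)    # tensor([[0, 0, 5, 9, 9]])
print(oob)  # tensor([[ True, False, False, False,  True]])
```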
@@ -579,8 +587,7 @@ class GaussianDiffusion:
         mask = torch.zeros_like(img)
         if causal:
             t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
-            mask = t == self.num_timesteps
-            t[mask] = self.num_timesteps-1
+            t, mask = causal_mask_and_fix(t, self.num_timesteps)
             mask = mask.repeat(img.shape[0], img.shape[1], 1)
         with th.no_grad():
             out = self.p_sample(
@@ -809,7 +816,7 @@ class GaussianDiffusion:
         mask = torch.zeros_like(img)
         if causal:
             t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
-            mask = t == self.num_timesteps
+            t, mask = causal_mask_and_fix(t, self.num_timesteps)
             t[mask] = self.num_timesteps-1
             mask = mask.repeat(img.shape[0], img.shape[1], 1)
         with th.no_grad():
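Both sampling loops now delegate to the helper, so the mask also flags the `t == -1` positions introduced by the widened clamp rather than only the `t == num_timesteps` ones. A rough sketch of the tensor shapes at this point, under the assumption (inferred from the surrounding calls, not confirmed by them) that the loop operates on a single shared timestep and `img` is `(batch, channels, S)`:

```python
import torch

batch, channels, S, num_timesteps = 2, 4, 8, 100
img = torch.randn(batch, channels, S)

# One shared timestep for the whole batch at this step of the loop.
t = torch.tensor([95])

# Emulate causal_timestep_adjustment(...).unsqueeze(1): one timestep per
# sequence position, allowed to run past the top of the schedule.
t = (t.unsqueeze(1).repeat(1, S) + torch.arange(S) * 2).clamp(-1, num_timesteps).unsqueeze(1)

oob = (t == num_timesteps).logical_or(t == -1)   # the mask causal_mask_and_fix returns
mask = oob.repeat(img.shape[0], img.shape[1], 1)
print(mask.shape)  # torch.Size([2, 4, 8]), True where the position is out of range
```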
@@ -894,8 +901,8 @@ class GaussianDiffusion:
         noise = th.randn_like(x_start)
 
         if len(t.shape) == 3:
-            t_mask = t != self.num_timesteps
-            t[t_mask.logical_not()] = self.num_timesteps-1
+            t, t_mask = causal_mask_and_fix(t, self.num_timesteps)
+            t_mask = t_mask.logical_not() # This is used to mask out losses for timesteps that are out of bounds.
         else:
             t_mask = torch.ones_like(x_start)
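In the training-loss path the helper's output is inverted: `causal_mask_and_fix` returns True where a position fell outside the schedule, and `logical_not()` turns that into a keep-mask so those positions contribute zero loss. A toy illustration (shapes and loss values invented; the masking logic is inlined from the helper so the snippet runs on its own):

```python
import torch

num_timesteps = 10
t = torch.tensor([[[-1, 2, 7, 10]]])            # (B=1, 1, S=4) per-position timesteps
oob = (t == num_timesteps).logical_or(t == -1)  # what causal_mask_and_fix flags
t_mask = oob.logical_not()                      # True = in range, keeps its loss

per_position_loss = torch.ones(1, 1, 4)
print(per_position_loss * t_mask)               # tensor([[[0., 1., 1., 0.]]])
```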
@@ -939,7 +946,7 @@ class GaussianDiffusion:
                 x_t=x_t,
                 t=t,
                 clip_denoised=False,
-            )["output"] * t_mask
+            )["output"]
             if self.loss_type == LossType.RESCALED_MSE:
                 # Divide by 1000 for equivalence with initial implementation.
                 # Without a factor of 1/1000, the VB term hurts the MSE term.
@@ -964,6 +971,7 @@ class GaussianDiffusion:
                 s_err = channel_balancing_fn(s_err)
             terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
             terms["mse"] = mean_flat(s_err)
+            terms["vb"] = terms["vb"] * t_mask
             terms["x_start_predicted"] = x_start_pred
             if "vb" in terms:
                 if channel_balancing_fn is not None:
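The `t_mask` multiplication has not disappeared, it has moved: instead of masking the `_vb_terms_bpd` output at the call site (the previous hunk), the stored `terms["vb"]` is masked once here. Assuming the term is not consumed between the two points, which the visible context suggests, the placements are equivalent; a trivial sketch with made-up values:

```python
import torch

vb_output = torch.tensor([0.5, 0.25, 0.125])
t_mask = torch.tensor([1.0, 0.0, 1.0])

terms_old = {"vb": vb_output * t_mask}      # old: mask applied at the call site

terms_new = {"vb": vb_output}               # new: store first...
terms_new["vb"] = terms_new["vb"] * t_mask  # ...mask afterwards

assert torch.equal(terms_old["vb"], terms_new["vb"])
```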
@@ -1085,14 +1093,20 @@ def test_causal_training_losses():
 
 def graph_causal_timestep_adjustment():
     import matplotlib.pyplot as plt
-    S = 400
+    S = 2000
     #slope=4
     num_timesteps=4000
-    for slpe in range(10, 400, 10):
+    for slpe in range(0, 200, 10):
         slope = slpe / 10
         t_res = []
         for t in range(num_timesteps, -1, -num_timesteps//50):
             T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
 
+            # The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
+            T_adj = (T == num_timesteps).logical_or(T == -1)
+            T[T_adj] = t
+            print(t, T.float().mean())
+
             t_res.append(T)
             plt.plot(T.numpy())
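The newly added `T_adj` step exists purely for plotting: values sitting outside the valid schedule would otherwise pin every curve to the clamp boundaries, so they are collapsed back onto the query timestep `t`. The adjustment in isolation, on an invented toy vector:

```python
import torch

num_timesteps, t = 10, 7
T = torch.tensor([-1, 3, 6, 9, 10, 10])  # toy causal timesteps, out of range at both ends

T_adj = (T == num_timesteps).logical_or(T == -1)
T[T_adj] = t
print(T)  # tensor([7, 3, 6, 9, 7, 7])
```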