From 72c0e4b56bbe1a0d7134d107d3f812859ef45810 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 11 Jul 2022 17:02:59 -0600 Subject: [PATCH] Fix rounding warning --- codes/models/diffusion/gaussian_diffusion.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 1cf59dd9..385c790f 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -25,7 +25,9 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T sequence is all [num_timesteps]. As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must - be considered at inference time. + be considered at inference time. Specifically, you should allot ((num_timesteps+causal_slope*(seq_len-1))/num_timesteps) + times more timesteps in inference for the same quality. + :param t: Batched timestep integers :param S: Sequence length. :param num_timesteps: Number of total timesteps. @@ -37,7 +39,7 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T S_sloped = causal_slope * (S-1) # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this # actually work, we map the existing t from the timescale specified to the model to the causal timescale: - adj_t = t * (num_timesteps + S_sloped) // num_timesteps + adj_t = torch.div(t * (num_timesteps + S_sloped), num_timesteps, rounding_mode='floor') adj_t = adj_t - S_sloped if add_jitter: t_gap = (num_timesteps + S_sloped) / num_timesteps @@ -1129,12 +1131,11 @@ def graph_causal_timestep_adjustment(): def graph_causal_timestep_adjustment_by_timestep(): import matplotlib.pyplot as plt S = 400 - slope=10 + slope=8 num_timesteps=4000 t_res = [] for t in range(num_timesteps, -1, -num_timesteps//50): T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] - t_res.append(T) plt.plot(T.numpy()) plt.ylim(0,num_timesteps)