Fix rounding warning

James Betker 2022-07-11 17:02:59 -06:00
parent 3edca1a906
commit 72c0e4b56b


@@ -25,7 +25,9 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
     sequence is all [num_timesteps].
     As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must
-    be considered at inference time.
+    be considered at inference time. Specifically, you should allot ((num_timesteps+causal_slope*(seq_len-1))/num_timesteps)
+    times more timesteps in inference for the same quality.
     :param t: Batched timestep integers
     :param S: Sequence length.
     :param num_timesteps: Number of total timesteps.
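
As a quick sanity check of the allotment factor added to the docstring, here is a back-of-the-envelope sketch using the constants that appear in graph_causal_timestep_adjustment_by_timestep later in this diff (num_timesteps=4000, S=400, causal_slope=8); any other combination works the same way:

```python
# Worked example of the inference allotment factor described in the docstring.
num_timesteps, seq_len, causal_slope = 4000, 400, 8  # constants from the graphing helper below
factor = (num_timesteps + causal_slope * (seq_len - 1)) / num_timesteps
print(factor)  # 1.798 -> allot ~1.8x as many inference timesteps for the same quality
```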
@@ -37,7 +39,7 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
     S_sloped = causal_slope * (S-1)
     # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
     # actually work, we map the existing t from the timescale specified to the model to the causal timescale:
-    adj_t = t * (num_timesteps + S_sloped) // num_timesteps
+    adj_t = torch.div(t * (num_timesteps + S_sloped), num_timesteps, rounding_mode='floor')
     adj_t = adj_t - S_sloped
     if add_jitter:
         t_gap = (num_timesteps + S_sloped) / num_timesteps
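
For context, the warning this commit fixes is PyTorch's __floordiv__ deprecation: on tensors, // routed through torch.floor_divide, whose rounding semantics were slated to change (truncation vs. flooring), so PyTorch releases current at the time emitted a UserWarning suggesting torch.div with an explicit rounding_mode. A minimal sketch of the equivalence, with illustrative values not taken from the diff:

```python
import torch

t = torch.tensor([0, 1000, 2000, 4000])   # illustrative timestep batch
num_timesteps = 4000
S_sloped = 8 * (400 - 1)                   # causal_slope * (S - 1)

# Old expression: tensor // int hits __floordiv__, which triggered the
# deprecation UserWarning on the PyTorch releases current in mid-2022.
old = t * (num_timesteps + S_sloped) // num_timesteps
# New expression: explicit floor rounding, warning-free.
new = torch.div(t * (num_timesteps + S_sloped), num_timesteps, rounding_mode='floor')

# All operands are non-negative here, so truncation and flooring agree and
# the two expressions produce identical results.
assert torch.equal(old, new)
```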
@@ -1129,12 +1131,11 @@ def graph_causal_timestep_adjustment():
 def graph_causal_timestep_adjustment_by_timestep():
     import matplotlib.pyplot as plt
     S = 400
-    slope=10
+    slope=8
     num_timesteps=4000
     t_res = []
     for t in range(num_timesteps, -1, -num_timesteps//50):
         T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
         t_res.append(T)
         plt.plot(T.numpy())
     plt.ylim(0,num_timesteps)