Fix rounding warning

This commit is contained in:
James Betker 2022-07-11 17:02:59 -06:00
parent 3edca1a906
commit 72c0e4b56b

View File

@ -25,7 +25,9 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
sequence is all [num_timesteps]. sequence is all [num_timesteps].
As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must
be considered at inference time. be considered at inference time. Specifically, you should allot ((num_timesteps+causal_slope*(seq_len-1))/num_timesteps)
times more timesteps in inference for the same quality.
:param t: Batched timestep integers :param t: Batched timestep integers
:param S: Sequence length. :param S: Sequence length.
:param num_timesteps: Number of total timesteps. :param num_timesteps: Number of total timesteps.
@ -37,7 +39,7 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
S_sloped = causal_slope * (S-1) S_sloped = causal_slope * (S-1)
# This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
# actually work, we map the existing t from the timescale specified to the model to the causal timescale: # actually work, we map the existing t from the timescale specified to the model to the causal timescale:
adj_t = t * (num_timesteps + S_sloped) // num_timesteps adj_t = torch.div(t * (num_timesteps + S_sloped), num_timesteps, rounding_mode='floor')
adj_t = adj_t - S_sloped adj_t = adj_t - S_sloped
if add_jitter: if add_jitter:
t_gap = (num_timesteps + S_sloped) / num_timesteps t_gap = (num_timesteps + S_sloped) / num_timesteps
@ -1129,12 +1131,11 @@ def graph_causal_timestep_adjustment():
def graph_causal_timestep_adjustment_by_timestep(): def graph_causal_timestep_adjustment_by_timestep():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
S = 400 S = 400
slope=10 slope=8
num_timesteps=4000 num_timesteps=4000
t_res = [] t_res = []
for t in range(num_timesteps, -1, -num_timesteps//50): for t in range(num_timesteps, -1, -num_timesteps//50):
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
t_res.append(T) t_res.append(T)
plt.plot(T.numpy()) plt.plot(T.numpy())
plt.ylim(0,num_timesteps) plt.ylim(0,num_timesteps)