Fix rounding warning
This commit is contained in:
parent
3edca1a906
commit
72c0e4b56b
|
@ -25,7 +25,9 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
|
|||
sequence is all [num_timesteps].
|
||||
|
||||
As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must
|
||||
be considered at inference time.
|
||||
be considered at inference time. Specifically, you should allot ((num_timesteps+causal_slope*(seq_len-1))/num_timesteps)
|
||||
times more timesteps in inference for the same quality.
|
||||
|
||||
:param t: Batched timestep integers
|
||||
:param S: Sequence length.
|
||||
:param num_timesteps: Number of total timesteps.
|
||||
|
@ -37,7 +39,7 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
|
|||
S_sloped = causal_slope * (S-1)
|
||||
# This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
|
||||
# actually work, we map the existing t from the timescale specified to the model to the causal timescale:
|
||||
adj_t = t * (num_timesteps + S_sloped) // num_timesteps
|
||||
adj_t = torch.div(t * (num_timesteps + S_sloped), num_timesteps, rounding_mode='floor')
|
||||
adj_t = adj_t - S_sloped
|
||||
if add_jitter:
|
||||
t_gap = (num_timesteps + S_sloped) / num_timesteps
|
||||
|
@ -1129,12 +1131,11 @@ def graph_causal_timestep_adjustment():
|
|||
def graph_causal_timestep_adjustment_by_timestep():
|
||||
import matplotlib.pyplot as plt
|
||||
S = 400
|
||||
slope=10
|
||||
slope=8
|
||||
num_timesteps=4000
|
||||
t_res = []
|
||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||
|
||||
t_res.append(T)
|
||||
plt.plot(T.numpy())
|
||||
plt.ylim(0,num_timesteps)
|
||||
|
|
Loading…
Reference in New Issue
Block a user