Add causal timestep adjustments
This commit is contained in:
parent
f5c246b879
commit
5f575b5d3c
|
@ -17,6 +17,63 @@ from .nn import mean_flat
|
|||
from .losses import normal_kl, discretized_gaussian_log_likelihood
|
||||
|
||||
|
||||
def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
|
||||
"""
|
||||
Remaps [t] from a batch of integers into a causal sequence [S] long where each sequence element is [causal_slope]
|
||||
timesteps advanced from the previous sequence element. At t=0, the sequence is all 0s and at t=[num_timesteps], the
|
||||
sequence is all [num_timesteps].
|
||||
|
||||
As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must
|
||||
be considered at inference time.
|
||||
:param t: Batched timestep integers
|
||||
:param S: Sequence length.
|
||||
:param num_timesteps: Number of total timesteps.
|
||||
:param causal_slope: The causal slope. Ex: "2" means each sequence element will be 2 timesteps ahead of its predecessor.
|
||||
:param add_jitter: Whether or not to add random jitter into the extra gaps between timesteps added by this function.
|
||||
Should be true for training and false for inference.
|
||||
:return: [b,S] sequence of timestep integers.
|
||||
"""
|
||||
S_sloped = causal_slope * (S-1)
|
||||
# This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
|
||||
# actually work, we map the existing t from the timescale specified to the model to the causal timescale:
|
||||
adj_t = t * (num_timesteps + S_sloped) // num_timesteps
|
||||
if add_jitter:
|
||||
jitter = (random.random() - .5) * S_sloped
|
||||
adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
|
||||
|
||||
# Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
|
||||
t = adj_t.unsqueeze(1).repeat(1, S)
|
||||
t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
|
||||
return t
|
||||
|
||||
|
||||
def graph_causal_timestep_adjustment():
|
||||
S = 400
|
||||
slope=4
|
||||
num_timesteps=4000
|
||||
#for num_timesteps in range(100, 4000, 200):
|
||||
t_res = []
|
||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||
t_res.append(T)
|
||||
plt.plot(T.numpy())
|
||||
plt.ylim(0,4000)
|
||||
plt.xlim(0,500)
|
||||
plt.savefig(f'{t}.png')
|
||||
plt.clf()
|
||||
|
||||
for i in range(len(t_res)):
|
||||
for j in range(len(t_res)):
|
||||
if i == j:
|
||||
continue
|
||||
#assert not torch.all(t_res[i] == t_res[j])
|
||||
plt.ylim(0,4000)
|
||||
plt.xlim(0,500)
|
||||
plt.ylabel('timestep')
|
||||
plt.savefig(f'{num_timesteps}.png')
|
||||
plt.clf()
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
|
@ -790,6 +847,14 @@ class GaussianDiffusion:
|
|||
output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
|
||||
"""
|
||||
Compute training losses for a causal diffusion process.
|
||||
"""
|
||||
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
|
||||
t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
|
||||
return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
|
Loading…
Reference in New Issue
Block a user