Add causal timestep adjustments

This commit is contained in:
James Betker 2022-07-07 15:17:47 -06:00 committed by GitHub
parent f5c246b879
commit 5f575b5d3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,6 +17,63 @@ from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood
def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
"""
Remaps [t] from a batch of integers into a causal sequence [S] long where each sequence element is [causal_slope]
timesteps advanced from the previous sequence element. At t=0, the sequence is all 0s and at t=[num_timesteps], the
sequence is all [num_timesteps].
As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must
be considered at inference time.
:param t: Batched timestep integers
:param S: Sequence length.
:param num_timesteps: Number of total timesteps.
:param causal_slope: The causal slope. Ex: "2" means each sequence element will be 2 timesteps ahead of its predecessor.
:param add_jitter: Whether or not to add random jitter into the extra gaps between timesteps added by this function.
Should be true for training and false for inference.
:return: [b,S] sequence of timestep integers.
"""
S_sloped = causal_slope * (S-1)
# This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
# actually work, we map the existing t from the timescale specified to the model to the causal timescale:
adj_t = t * (num_timesteps + S_sloped) // num_timesteps
if add_jitter:
jitter = (random.random() - .5) * S_sloped
adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
# Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
t = adj_t.unsqueeze(1).repeat(1, S)
t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
return t
def graph_causal_timestep_adjustment():
S = 400
slope=4
num_timesteps=4000
#for num_timesteps in range(100, 4000, 200):
t_res = []
for t in range(num_timesteps, -1, -num_timesteps//50):
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
t_res.append(T)
plt.plot(T.numpy())
plt.ylim(0,4000)
plt.xlim(0,500)
plt.savefig(f'{t}.png')
plt.clf()
for i in range(len(t_res)):
for j in range(len(t_res)):
if i == j:
continue
#assert not torch.all(t_res[i] == t_res[j])
plt.ylim(0,4000)
plt.xlim(0,500)
plt.ylabel('timestep')
plt.savefig(f'{num_timesteps}.png')
plt.clf()
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
@ -790,6 +847,14 @@ class GaussianDiffusion:
output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
"""
Compute training losses for a causal diffusion process.
"""
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
"""
Compute training losses for a single timestep.