forked from mrq/DL-Art-School
Add causal timestep adjustments
parent f5c246b879
commit 5f575b5d3c
@@ -17,6 +17,63 @@ from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood


def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
    """
    Remaps [t] from a batch of integers into a causal sequence [S] elements long in which each element is
    [causal_slope] timesteps advanced from its predecessor. At t=0 the sequence is all 0s, and at
    t=[num_timesteps] the sequence is all [num_timesteps].

    As a result of the last property, longer sequences will have larger "gaps" between their timesteps in
    continuous space. This must be accounted for at inference time.
    :param t: Batched timestep integers.
    :param S: Sequence length.
    :param num_timesteps: Number of total timesteps.
    :param causal_slope: The causal slope. Ex: "2" means each sequence element will be 2 timesteps ahead of its predecessor.
    :param add_jitter: Whether or not to add random jitter into the extra gaps between timesteps added by this
                       function. Should be True for training and False for inference.
    :return: [b,S] sequence of timestep integers.
    """
    S_sloped = causal_slope * (S - 1)
    # This algorithm adds causality by simply appending S_sloped additional timesteps. To make that work, map the
    # existing t from the timescale the model was configured with onto the causal timescale:
    adj_t = t * (num_timesteps + S_sloped) // num_timesteps
    if add_jitter:
        jitter = (random.random() - .5) * S_sloped
        adj_t = (adj_t + jitter).clamp(0, num_timesteps + S_sloped)

    # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the
    # specified slope.
    t = adj_t.unsqueeze(1).repeat(1, S)
    t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
    return t


def graph_causal_timestep_adjustment():
    # Debugging helper: sweep t and plot the causal timestep curve produced for each value.
    S = 400
    slope = 4
    num_timesteps = 4000
    #for num_timesteps in range(100, 4000, 200):
    t_res = []
    for t in range(num_timesteps, -1, -num_timesteps // 50):
        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
        t_res.append(T)
        plt.plot(T.numpy())
        plt.ylim(0, 4000)
        plt.xlim(0, 500)
        plt.savefig(f'{t}.png')
        plt.clf()

    # Disabled check that no two t values collapse onto the same causal sequence.
    for i in range(len(t_res)):
        for j in range(len(t_res)):
            if i == j:
                continue
            #assert not torch.all(t_res[i] == t_res[j])
    plt.ylim(0, 4000)
    plt.xlim(0, 500)
    plt.ylabel('timestep')
    plt.savefig(f'{num_timesteps}.png')
    plt.clf()

def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
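
To make the remapping concrete, here is a minimal usage sketch (not part of the commit; the batch size, sequence length, and slope are arbitrary illustrative values):

import torch

num_timesteps = 4000              # total diffusion steps the schedule was built with
S = 8                             # short sequence so the output is easy to inspect
t = torch.tensor([2000, 4000])    # one "global" timestep per batch element

seq_t = causal_timestep_adjustment(t, S, num_timesteps, causal_slope=2, add_jitter=False)
# seq_t has shape [2, 8]. For t=2000 the first position carries the largest timestep and each later
# position is 2 steps behind it; for t=num_timesteps every position clamps to num_timesteps.
print(seq_t)

With add_jitter=True (the training setting), the whole sequence is additionally shifted by a random offset in (-S_sloped/2, S_sloped/2) before clamping.
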
@@ -790,6 +847,14 @@ class GaussianDiffusion:
        output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
        """
        Compute training losses for a causal diffusion process.
        """
        assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the final axis being the time axis."
        t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
        return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
        """
        Compute training losses for a single timestep.
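
For orientation, a rough sketch of how the new entry point might be called from a training step (not part of the commit; `diffusion` and `model` are hypothetical placeholders for a GaussianDiffusion instance from this module and any model accepted by training_losses, and the shapes are illustrative only):

import torch

b, c, seq_len = 4, 80, 400                            # e.g. a batch of 1d spectrogram-like sequences
x_start = torch.randn(b, c, seq_len)                  # must be 3-dimensional: [batch, channels, time]
t = torch.randint(0, diffusion.num_timesteps, (b,))   # one scalar timestep per batch element, as usual

# Internally, t is expanded into a [b, seq_len] causal timestep map (with jitter) and handed to the
# ordinary training_losses path, so the return value is whatever training_losses normally produces.
losses = diffusion.causal_training_losses(model, x_start, t, causal_slope=4)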