Rescale causal scale along with timestep spacing

This commit is contained in:
James Betker 2022-07-09 22:09:25 -06:00
parent b432d7c7de
commit 3edca1a906
3 changed files with 36 additions and 12 deletions

View File

@ -412,6 +412,9 @@ class GaussianDiffusion:
return t.float() * (1000.0 / self.num_timesteps)
return t
def _get_scale_ratio(self):
return 1
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
@ -586,7 +589,7 @@ class GaussianDiffusion:
t = th.tensor([i] * shape[0], device=device)
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
t, mask = causal_mask_and_fix(t, self.num_timesteps)
mask = mask.repeat(img.shape[0], img.shape[1], 1)
with th.no_grad():
@ -816,7 +819,7 @@ class GaussianDiffusion:
t = th.tensor([i] * shape[0], device=device)
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
t, mask = causal_mask_and_fix(t, self.num_timesteps)
t[mask] = self.num_timesteps-1
mask = mask.repeat(img.shape[0], img.shape[1], 1)
@ -880,7 +883,7 @@ class GaussianDiffusion:
Compute training losses for a causal diffusion process.
"""
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=True)
ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start.
return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
@ -1095,19 +1098,18 @@ def test_causal_training_losses():
def graph_causal_timestep_adjustment():
import matplotlib.pyplot as plt
S = 2000
S = 400
#slope=4
num_timesteps=4000
for slpe in range(0, 200, 10):
for slpe in range(10, 400, 50):
slope = slpe / 10
t_res = []
for t in range(num_timesteps, -1, -num_timesteps//50):
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
# The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
T_adj = (T == num_timesteps).logical_or(T == -1)
T[T_adj] = t
print(t, T.float().mean())
#T_adj = (T == num_timesteps).logical_or(T == -1)
#T[T_adj] = t
t_res.append(T)
plt.plot(T.numpy())
@ -1116,13 +1118,32 @@ def graph_causal_timestep_adjustment():
for j in range(len(t_res)):
if i == j:
continue
#assert not torch.all(t_res[i] == t_res[j])
assert not torch.all(t_res[i] == t_res[j])
plt.ylim(0,num_timesteps)
plt.xlim(0,4000)
plt.ylabel('timestep')
plt.savefig(f'{slpe}.png')
plt.clf()
def graph_causal_timestep_adjustment_by_timestep():
import matplotlib.pyplot as plt
S = 400
slope=10
num_timesteps=4000
t_res = []
for t in range(num_timesteps, -1, -num_timesteps//50):
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
t_res.append(T)
plt.plot(T.numpy())
plt.ylim(0,num_timesteps)
plt.xlim(0,4000)
plt.ylabel('timestep')
plt.savefig(f'{t}.png')
plt.clf()
if __name__ == '__main__':
#test_causal_training_losses()
graph_causal_timestep_adjustment()
#graph_causal_timestep_adjustment()
graph_causal_timestep_adjustment_by_timestep()

View File

@ -113,6 +113,9 @@ class SpacedDiffusion(GaussianDiffusion):
# Scaling is done by the wrapped model.
return t
def _get_scale_ratio(self):
return self.num_timesteps / self.original_num_steps
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):

View File

@ -424,11 +424,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\18000_generator.pth'
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\22000_generator_ema.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 256,
'diffusion_steps': 100,
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
'causal': True, 'causal_slope': 1,