Rescale causal scale along with timestep spacing
This commit is contained in:
parent
b432d7c7de
commit
3edca1a906
|
@ -412,6 +412,9 @@ class GaussianDiffusion:
|
|||
return t.float() * (1000.0 / self.num_timesteps)
|
||||
return t
|
||||
|
||||
def _get_scale_ratio(self):
|
||||
return 1
|
||||
|
||||
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||
"""
|
||||
Compute the mean for the previous step, given a function cond_fn that
|
||||
|
@ -586,7 +589,7 @@ class GaussianDiffusion:
|
|||
t = th.tensor([i] * shape[0], device=device)
|
||||
mask = torch.zeros_like(img)
|
||||
if causal:
|
||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
|
||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
|
||||
t, mask = causal_mask_and_fix(t, self.num_timesteps)
|
||||
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
||||
with th.no_grad():
|
||||
|
@ -816,7 +819,7 @@ class GaussianDiffusion:
|
|||
t = th.tensor([i] * shape[0], device=device)
|
||||
mask = torch.zeros_like(img)
|
||||
if causal:
|
||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
|
||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
|
||||
t, mask = causal_mask_and_fix(t, self.num_timesteps)
|
||||
t[mask] = self.num_timesteps-1
|
||||
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
||||
|
@ -880,7 +883,7 @@ class GaussianDiffusion:
|
|||
Compute training losses for a causal diffusion process.
|
||||
"""
|
||||
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
|
||||
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
|
||||
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=True)
|
||||
ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start.
|
||||
return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
|
||||
|
||||
|
@ -1095,19 +1098,18 @@ def test_causal_training_losses():
|
|||
|
||||
def graph_causal_timestep_adjustment():
|
||||
import matplotlib.pyplot as plt
|
||||
S = 2000
|
||||
S = 400
|
||||
#slope=4
|
||||
num_timesteps=4000
|
||||
for slpe in range(0, 200, 10):
|
||||
for slpe in range(10, 400, 50):
|
||||
slope = slpe / 10
|
||||
t_res = []
|
||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||
|
||||
# The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
|
||||
T_adj = (T == num_timesteps).logical_or(T == -1)
|
||||
T[T_adj] = t
|
||||
print(t, T.float().mean())
|
||||
#T_adj = (T == num_timesteps).logical_or(T == -1)
|
||||
#T[T_adj] = t
|
||||
|
||||
t_res.append(T)
|
||||
plt.plot(T.numpy())
|
||||
|
@ -1116,13 +1118,32 @@ def graph_causal_timestep_adjustment():
|
|||
for j in range(len(t_res)):
|
||||
if i == j:
|
||||
continue
|
||||
#assert not torch.all(t_res[i] == t_res[j])
|
||||
assert not torch.all(t_res[i] == t_res[j])
|
||||
plt.ylim(0,num_timesteps)
|
||||
plt.xlim(0,4000)
|
||||
plt.ylabel('timestep')
|
||||
plt.savefig(f'{slpe}.png')
|
||||
plt.clf()
|
||||
|
||||
|
||||
def graph_causal_timestep_adjustment_by_timestep():
|
||||
import matplotlib.pyplot as plt
|
||||
S = 400
|
||||
slope=10
|
||||
num_timesteps=4000
|
||||
t_res = []
|
||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||
|
||||
t_res.append(T)
|
||||
plt.plot(T.numpy())
|
||||
plt.ylim(0,num_timesteps)
|
||||
plt.xlim(0,4000)
|
||||
plt.ylabel('timestep')
|
||||
plt.savefig(f'{t}.png')
|
||||
plt.clf()
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_causal_training_losses()
|
||||
graph_causal_timestep_adjustment()
|
||||
#graph_causal_timestep_adjustment()
|
||||
graph_causal_timestep_adjustment_by_timestep()
|
|
@ -113,6 +113,9 @@ class SpacedDiffusion(GaussianDiffusion):
|
|||
# Scaling is done by the wrapped model.
|
||||
return t
|
||||
|
||||
def _get_scale_ratio(self):
|
||||
return self.num_timesteps / self.original_num_steps
|
||||
|
||||
|
||||
class _WrappedModel:
|
||||
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
||||
|
|
|
@ -424,11 +424,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
if __name__ == '__main__':
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\18000_generator.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\22000_generator_ema.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
'diffusion_steps': 256,
|
||||
'diffusion_steps': 100,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
|
||||
'causal': True, 'causal_slope': 1,
|
||||
|
|
Loading…
Reference in New Issue
Block a user