diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 678adcba..1cf59dd9 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -412,6 +412,9 @@ class GaussianDiffusion: return t.float() * (1000.0 / self.num_timesteps) return t + def _get_scale_ratio(self): + return 1 + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that @@ -586,7 +589,7 @@ class GaussianDiffusion: t = th.tensor([i] * shape[0], device=device) mask = torch.zeros_like(img) if causal: - t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1) + t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1) t, mask = causal_mask_and_fix(t, self.num_timesteps) mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): @@ -816,7 +819,7 @@ class GaussianDiffusion: t = th.tensor([i] * shape[0], device=device) mask = torch.zeros_like(img) if causal: - t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1) + t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1) t, mask = causal_mask_and_fix(t, self.num_timesteps) t[mask] = self.num_timesteps-1 mask = mask.repeat(img.shape[0], img.shape[1], 1) @@ -880,7 +883,7 @@ class GaussianDiffusion: Compute training losses for a causal diffusion process. """ assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis." - ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True) + ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=True) ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start. return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn) @@ -1095,19 +1098,18 @@ def test_causal_training_losses(): def graph_causal_timestep_adjustment(): import matplotlib.pyplot as plt - S = 2000 + S = 400 #slope=4 num_timesteps=4000 - for slpe in range(0, 200, 10): + for slpe in range(10, 400, 50): slope = slpe / 10 t_res = [] for t in range(num_timesteps, -1, -num_timesteps//50): T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] # The following adjustment makes it easier to visualize the timestep regions where the model is actually working. - T_adj = (T == num_timesteps).logical_or(T == -1) - T[T_adj] = t - print(t, T.float().mean()) + #T_adj = (T == num_timesteps).logical_or(T == -1) + #T[T_adj] = t t_res.append(T) plt.plot(T.numpy()) @@ -1116,13 +1118,32 @@ def graph_causal_timestep_adjustment(): for j in range(len(t_res)): if i == j: continue - #assert not torch.all(t_res[i] == t_res[j]) + assert not torch.all(t_res[i] == t_res[j]) plt.ylim(0,num_timesteps) plt.xlim(0,4000) plt.ylabel('timestep') plt.savefig(f'{slpe}.png') plt.clf() + +def graph_causal_timestep_adjustment_by_timestep(): + import matplotlib.pyplot as plt + S = 400 + slope=10 + num_timesteps=4000 + t_res = [] + for t in range(num_timesteps, -1, -num_timesteps//50): + T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] + + t_res.append(T) + plt.plot(T.numpy()) + plt.ylim(0,num_timesteps) + plt.xlim(0,4000) + plt.ylabel('timestep') + plt.savefig(f'{t}.png') + plt.clf() + if __name__ == '__main__': #test_causal_training_losses() - graph_causal_timestep_adjustment() \ No newline at end of file + #graph_causal_timestep_adjustment() + graph_causal_timestep_adjustment_by_timestep() \ No newline at end of file diff --git a/codes/models/diffusion/respace.py b/codes/models/diffusion/respace.py index 78403ebb..03aa9234 100644 --- a/codes/models/diffusion/respace.py +++ b/codes/models/diffusion/respace.py @@ -113,6 +113,9 @@ class SpacedDiffusion(GaussianDiffusion): # Scaling is done by the wrapped model. return t + def _get_scale_ratio(self): + return self.num_timesteps / self.original_num_steps + class _WrappedModel: def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 2e7c994a..cc4890ec 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -424,11 +424,11 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\18000_generator.pth' + load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\22000_generator_ema.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. - 'diffusion_steps': 256, + 'diffusion_steps': 100, 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False, 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen', 'causal': True, 'causal_slope': 1,