From 3edca1a906bc266ea930fbc513337b42473ea888 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sat, 9 Jul 2022 22:09:25 -0600
Subject: [PATCH] Rescale causal slope along with timestep spacing

---
 codes/models/diffusion/gaussian_diffusion.py | 41 +++++++++++++++-----
 codes/models/diffusion/respace.py            |  3 ++
 codes/trainer/eval/music_diffusion_fid.py    |  4 +-
 3 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 678adcba..1cf59dd9 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -412,6 +412,9 @@ class GaussianDiffusion:
             return t.float() * (1000.0 / self.num_timesteps)
         return t
 
+    def _get_scale_ratio(self):  # factor that the causal sampling/training paths multiply causal_slope by
+        return 1  # base class applies no rescaling; SpacedDiffusion overrides this method
+
     def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
         """
         Compute the mean for the previous step, given a function cond_fn that
@@ -586,7 +589,7 @@ class GaussianDiffusion:
             t = th.tensor([i] * shape[0], device=device)
             mask = torch.zeros_like(img)
             if causal:
-                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
+                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
                 t, mask = causal_mask_and_fix(t, self.num_timesteps)
                 mask = mask.repeat(img.shape[0], img.shape[1], 1)
             with th.no_grad():
@@ -816,7 +819,7 @@ class GaussianDiffusion:
             t = th.tensor([i] * shape[0], device=device)
             mask = torch.zeros_like(img)
             if causal:
-                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
+                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
                 t, mask = causal_mask_and_fix(t, self.num_timesteps)
                 t[mask] = self.num_timesteps-1
                 mask = mask.repeat(img.shape[0], img.shape[1], 1)
@@ -880,7 +883,7 @@ class GaussianDiffusion:
         Compute training losses for a causal diffusion process.
         """
         assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
-        ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
+        ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=True)
         ct = ct.unsqueeze(1)  # Necessary to make the output shape compatible with x_start.
         return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
 
@@ -1095,19 +1098,18 @@ def test_causal_training_losses():
 
 def graph_causal_timestep_adjustment():
     import matplotlib.pyplot as plt
-    S = 2000
+    S = 400
     #slope=4
     num_timesteps=4000
-    for slpe in range(0, 200, 10):
+    for slpe in range(10, 400, 50):
         slope = slpe / 10
         t_res = []
         for t in range(num_timesteps, -1, -num_timesteps//50):
             T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
 
             # The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
-            T_adj = (T == num_timesteps).logical_or(T == -1)
-            T[T_adj] = t
-            print(t, T.float().mean())
+            #T_adj = (T == num_timesteps).logical_or(T == -1)
+            #T[T_adj] = t
 
             t_res.append(T)
             plt.plot(T.numpy())
@@ -1116,13 +1118,32 @@ def graph_causal_timestep_adjustment():
             for j in range(len(t_res)):
                 if i == j:
                     continue
-                #assert not torch.all(t_res[i] == t_res[j])
+                assert not torch.all(t_res[i] == t_res[j])
         plt.ylim(0,num_timesteps)
         plt.xlim(0,4000)
         plt.ylabel('timestep')
         plt.savefig(f'{slpe}.png')
         plt.clf()
 
+
+def graph_causal_timestep_adjustment_by_timestep():
+    """Debug aid: save one '<t>.png' plot of the causally adjusted timesteps for each base timestep t."""
+    import matplotlib.pyplot as plt
+    S = 400  # length argument handed to causal_timestep_adjustment
+    slope = 10
+    num_timesteps = 4000
+    # Sweep t downward from num_timesteps to 0 in strides of num_timesteps // 50.
+    for t in range(num_timesteps, -1, -num_timesteps // 50):
+        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
+        plt.plot(T.numpy())
+        plt.ylim(0, num_timesteps)
+        # NOTE(review): x range stays 4000 although presumably only S points are drawn - confirm intended.
+        plt.xlim(0, 4000)
+        plt.ylabel('timestep')
+        plt.savefig(f'{t}.png')
+        plt.clf()
+
 if __name__ == '__main__':
     #test_causal_training_losses()
-    graph_causal_timestep_adjustment()
\ No newline at end of file
+    #graph_causal_timestep_adjustment()
+    graph_causal_timestep_adjustment_by_timestep()
\ No newline at end of file
diff --git a/codes/models/diffusion/respace.py b/codes/models/diffusion/respace.py
index 78403ebb..03aa9234 100644
--- a/codes/models/diffusion/respace.py
+++ b/codes/models/diffusion/respace.py
@@ -113,6 +113,9 @@ class SpacedDiffusion(GaussianDiffusion):
         # Scaling is done by the wrapped model.
         return t
 
+    def _get_scale_ratio(self):  # overrides GaussianDiffusion._get_scale_ratio, whose result multiplies causal_slope
+        return self.num_timesteps / self.original_num_steps  # true division, so the ratio is a float
+
 
 class _WrappedModel:
     def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 2e7c994a..cc4890ec 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -424,11 +424,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\18000_generator.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\22000_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 256,
+                'diffusion_steps': 100,
                 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
                 'causal': True, 'causal_slope': 1,