From 2d1cb83c1d4056d4b7fcc73ba6ee29e8d94cadfd Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Fri, 4 Mar 2022 10:40:14 -0700
Subject: [PATCH] Add a deterministic timestep sampler, with provisions to
 employ it every n steps

---
 codes/models/diffusion/resample.py            | 15 ++++++++++++
 .../injectors/gaussian_diffusion_injector.py  | 23 ++++++++-----------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/codes/models/diffusion/resample.py b/codes/models/diffusion/resample.py
index c82eccdc..c59a1cc5 100644
--- a/codes/models/diffusion/resample.py
+++ b/codes/models/diffusion/resample.py
@@ -67,6 +67,21 @@ class UniformSampler(ScheduleSampler):
         return self._weights
 
 
+class DeterministicSampler:
+    """
+    Returns the same equally spread-out sampling schedule every time it is called.
+    """
+    def __init__(self, diffusion):
+        super().__init__()
+        self.timesteps = diffusion.num_timesteps
+
+    def sample(self, batch_size, device):
+        rnge = th.arange(0, batch_size, device=device).float() / batch_size
+        indices = (rnge * self.timesteps).long()
+        weights = th.ones_like(indices).float()
+        return indices, weights
+
+
 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
         """
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 8119ef82..f93c363e 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -5,7 +5,7 @@ import torch
 from torch.cuda.amp import autocast
 
 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
-from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
+from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get
@@ -26,22 +26,22 @@ class GaussianDiffusionInjector(Injector):
         self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
+        self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         hq = state[self.input]
 
-        # In eval mode, seed torch with a deterministic seed for reproducibility.
-        if not gen.training:
-            torch.manual_seed(0)
-            random.seed(0)
-
         with autocast(enabled=self.env['opt']['fp16']):
+            if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
+                sampler = DeterministicSampler(self.diffusion)
+            else:
+                sampler = self.schedule_sampler
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
-            t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
+            t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
-            if isinstance(self.schedule_sampler, LossAwareSampler):
-                self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
+            if isinstance(sampler, LossAwareSampler):
+                sampler.update_with_local_losses(t, diffusion_outputs['losses'])
 
             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
@@ -52,11 +52,6 @@ class GaussianDiffusionInjector(Injector):
                     self.output_variational_bounds_key: diffusion_outputs['vb'],
                     self.output_x_start_key: diffusion_outputs['x_start_predicted']})
 
-        # Absolutely critical to undo the above seed.
-        if not gen.training:
-            torch.manual_seed(int(time.time()))
-            random.seed(int(time.time()))
-
         return out