diff --git a/codes/models/diffusion/resample.py b/codes/models/diffusion/resample.py
index c82eccdc..c59a1cc5 100644
--- a/codes/models/diffusion/resample.py
+++ b/codes/models/diffusion/resample.py
@@ -67,6 +67,21 @@ class UniformSampler(ScheduleSampler):
         return self._weights
 
 
+class DeterministicSampler:
+    """
+    Returns the same equally spread-out sampling schedule every time it is called.
+    """
+    def __init__(self, diffusion):
+        super().__init__()
+        self.timesteps = diffusion.num_timesteps
+
+    def sample(self, batch_size, device):
+        rnge = th.arange(0, batch_size, device=device).float() / batch_size
+        indices = (rnge * self.timesteps).long()
+        weights = th.ones_like(indices).float()
+        return indices, weights
+
+
 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
         """
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 8119ef82..f93c363e 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -5,7 +5,7 @@ import torch
 from torch.cuda.amp import autocast
 
 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
-from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
+from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get
@@ -26,22 +26,22 @@ class GaussianDiffusionInjector(Injector):
         self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
+        self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         hq = state[self.input]
 
-        # In eval mode, seed torch with a deterministic seed for reproducibility.
-        if not gen.training:
-            torch.manual_seed(0)
-            random.seed(0)
-
         with autocast(enabled=self.env['opt']['fp16']):
+            if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
+                sampler = DeterministicSampler(self.diffusion)
+            else:
+                sampler = self.schedule_sampler
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
-            t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
+            t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
-            if isinstance(self.schedule_sampler, LossAwareSampler):
-                self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
+            if isinstance(sampler, LossAwareSampler):
+                sampler.update_with_local_losses(t, diffusion_outputs['losses'])
 
         if len(self.extra_model_output_keys) > 0:
             assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
@@ -52,11 +52,6 @@ class GaussianDiffusionInjector(Injector):
                     self.output_variational_bounds_key: diffusion_outputs['vb'],
                     self.output_x_start_key: diffusion_outputs['x_start_predicted']})
 
-        # Absolutely critical to undo the above seed.
-        if not gen.training:
-            torch.manual_seed(int(time.time()))
-            random.seed(int(time.time()))
-
         return out
 
 
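Note (illustration, not part of the patch): the added DeterministicSampler spreads timesteps evenly across the batch instead of drawing them at random, so a given batch size always produces the same schedule. The sketch below reproduces the arithmetic of DeterministicSampler.sample() for a hypothetical diffusion process with 1000 timesteps and a batch of 8; the timestep count and batch size are assumed values, not taken from the repo.

    import torch as th

    num_timesteps = 1000   # hypothetical value of diffusion.num_timesteps
    batch_size = 8         # hypothetical batch size

    # Same computation as the patch's DeterministicSampler.sample(), inlined for illustration.
    rnge = th.arange(0, batch_size).float() / batch_size   # [0.000, 0.125, ..., 0.875]
    indices = (rnge * num_timesteps).long()                # tensor([0, 125, 250, 375, 500, 625, 750, 875])
    weights = th.ones_like(indices).float()                # uniform importance weights, i.e. no loss reweighting

    print(indices.tolist())   # [0, 125, 250, 375, 500, 625, 750, 875]
    print(weights.tolist())   # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]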