diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index 292bf1e1..ce34ba66 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -1,3 +1,6 @@ +import random +import time + import torch from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule @@ -26,6 +29,12 @@ class GaussianDiffusionInjector(Injector): def forward(self, state): gen = self.env['generators'][self.opt['generator']] hq = state[self.input] + + # In eval mode, seed torch with a deterministic seed for reproducibility. + if not gen.trainable: + torch.manual_seed(0) + random.seed(0) + model_inputs = {k: state[v] for k, v in self.model_input_keys.items()} t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device) diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs) @@ -40,6 +49,12 @@ class GaussianDiffusionInjector(Injector): out.update({self.output: diffusion_outputs['mse'], self.output_variational_bounds_key: diffusion_outputs['vb'], self.output_x_start_key: diffusion_outputs['x_start_predicted']}) + + # Absolutely critical to undo the above seed. + if not gen.trainable: + torch.manual_seed(int(time.time())) + random.seed(int(time.time())) + return out