Try to make diffusion validator more reproducible

This commit is contained in:
James Betker 2021-11-24 09:38:10 -07:00
parent 5956eb757c
commit 47fe032a3d

View File

@ -1,3 +1,6 @@
import random
import time
import torch import torch
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
@ -26,6 +29,12 @@ class GaussianDiffusionInjector(Injector):
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
hq = state[self.input] hq = state[self.input]
# In eval mode, seed torch with a deterministic seed for reproducibility.
if not gen.trainable:
torch.manual_seed(0)
random.seed(0)
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()} model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device) t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs) diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
@ -40,6 +49,12 @@ class GaussianDiffusionInjector(Injector):
out.update({self.output: diffusion_outputs['mse'], out.update({self.output: diffusion_outputs['mse'],
self.output_variational_bounds_key: diffusion_outputs['vb'], self.output_variational_bounds_key: diffusion_outputs['vb'],
self.output_x_start_key: diffusion_outputs['x_start_predicted']}) self.output_x_start_key: diffusion_outputs['x_start_predicted']})
# Absolutely critical to undo the above seed.
if not gen.trainable:
torch.manual_seed(int(time.time()))
random.seed(int(time.time()))
return out return out