forked from mrq/DL-Art-School
Try to make diffusion validator more reproducible
This commit is contained in:
parent
5956eb757c
commit
47fe032a3d
|
@ -1,3 +1,6 @@
|
|||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
||||
|
@ -26,6 +29,12 @@ class GaussianDiffusionInjector(Injector):
|
|||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
hq = state[self.input]
|
||||
|
||||
# In eval mode, seed torch with a deterministic seed for reproducibility.
|
||||
if not gen.trainable:
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||
|
@ -40,6 +49,12 @@ class GaussianDiffusionInjector(Injector):
|
|||
out.update({self.output: diffusion_outputs['mse'],
|
||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||
|
||||
# Absolutely critical to undo the above seed.
|
||||
if not gen.trainable:
|
||||
torch.manual_seed(int(time.time()))
|
||||
random.seed(int(time.time()))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user