forked from mrq/DL-Art-School
Try to make diffusion validator more reproducible
This commit is contained in:
parent
5956eb757c
commit
47fe032a3d
|
@ -1,3 +1,6 @@
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
||||||
|
@ -26,6 +29,12 @@ class GaussianDiffusionInjector(Injector):
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
hq = state[self.input]
|
hq = state[self.input]
|
||||||
|
|
||||||
|
# In eval mode, seed torch with a deterministic seed for reproducibility.
|
||||||
|
if not gen.trainable:
|
||||||
|
torch.manual_seed(0)
|
||||||
|
random.seed(0)
|
||||||
|
|
||||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
||||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||||
|
@ -40,6 +49,12 @@ class GaussianDiffusionInjector(Injector):
|
||||||
out.update({self.output: diffusion_outputs['mse'],
|
out.update({self.output: diffusion_outputs['mse'],
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||||
|
|
||||||
|
# Absolutely critical to undo the above seed.
|
||||||
|
if not gen.trainable:
|
||||||
|
torch.manual_seed(int(time.time()))
|
||||||
|
random.seed(int(time.time()))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user