feed direct inputs into gd

This commit is contained in:
James Betker 2022-03-26 08:36:19 -06:00
parent 6909f196b4
commit 9b90472e15

View File

@ -39,7 +39,7 @@ class GaussianDiffusionInjector(Injector):
else:
sampler = self.schedule_sampler
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
t, weights = sampler.sample(hq.shape[0], hq.device)
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
if isinstance(sampler, LossAwareSampler):