diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index 45988323..2c94846b 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -39,7 +39,7 @@ class GaussianDiffusionInjector(Injector): else: sampler = self.schedule_sampler self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically. - model_inputs = {k: state[v] for k, v in self.model_input_keys.items()} + model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()} t, weights = sampler.sample(hq.shape[0], hq.device) diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs) if isinstance(sampler, LossAwareSampler):