forked from mrq/DL-Art-School
feed direct inputs into gd
This commit is contained in:
parent
6909f196b4
commit
9b90472e15
|
@ -39,7 +39,7 @@ class GaussianDiffusionInjector(Injector):
|
|||
else:
|
||||
sampler = self.schedule_sampler
|
||||
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
|
||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
|
||||
t, weights = sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||
if isinstance(sampler, LossAwareSampler):
|
||||
|
|
Loading…
Reference in New Issue
Block a user