diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index ed6213da..576dc52b 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -16,7 +16,8 @@ class GaussianDiffusionInjector(Injector): self.output_variational_bounds_key = opt['out_key_vb_loss'] self.output_x_start_key = opt['out_key_x_start'] opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule']) - opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']]) # TODO: Figure out how these work and specify them differently. + opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], + [opt['beta_schedule']['num_diffusion_timesteps']]) self.diffusion = SpacedDiffusion(**opt['diffusion_args']) self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion) self.model_input_keys = opt_get(opt, ['model_input_keys'], []) @@ -41,7 +42,8 @@ class GaussianDiffusionInferenceInjector(Injector): self.generator = opt['generator'] self.output_shape = opt['output_shape'] opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule']) - opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])]) # TODO: Figure out how these work and specify them differently. + opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], + [opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])]) self.diffusion = SpacedDiffusion(**opt['diffusion_args']) self.model_input_keys = opt_get(opt, ['model_input_keys'], [])