From 80d44043674b7fe07e00d9ba13f9b914e70b5723 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 5 Jun 2021 13:40:32 -0600
Subject: [PATCH] A few fixes:

- Output better prediction of xstart from eps
- Support LossAwareSampler
- Support AdamW
---
 codes/models/diffusion/gaussian_diffusion.py           | 3 +--
 codes/trainer/injectors/gaussian_diffusion_injector.py | 4 +++-
 codes/trainer/steps.py                                 | 4 ++++
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 61a2d25b..2b57f34e 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -757,7 +757,6 @@ class GaussianDiffusion:
         terms = {}

         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
-            x_start_pred = torch.zeros_like(x_start)  # This type of model doesn't predict x_start.
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
                 x_start=x_start,
@@ -803,7 +802,7 @@ class GaussianDiffusion:
                 x_start_pred = model_output
             elif self.model_mean_type == ModelMeanType.EPSILON:
                 target = noise
-                x_start_pred = x_t - model_output
+                x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 00636497..ce019937 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -1,7 +1,7 @@
 import torch

 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
-from models.diffusion.resample import create_named_schedule_sampler
+from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get
@@ -27,6 +27,8 @@ class GaussianDiffusionInjector(Injector):
         model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
         t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
         diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+        if isinstance(self.schedule_sampler, LossAwareSampler):
+            self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
         return {self.output: diffusion_outputs['mse'],
                 self.output_variational_bounds_key: diffusion_outputs['vb'],
                 self.output_x_start_key: diffusion_outputs['x_start_predicted']}
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index df4cade0..20533583 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -101,6 +101,10 @@ class ConfigurableStep(Module):
             opt = torch.optim.Adam(list(optim_params.values()),
                                    weight_decay=opt_config['weight_decay'],
                                    betas=(opt_config['beta1'], opt_config['beta2']))
+        elif self.step_opt['optimizer'] == 'adamw':
+            opt = torch.optim.AdamW(list(optim_params.values()),
+                                    weight_decay=opt_config['weight_decay'],
+                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum