A few fixes:

- Output better prediction of xstart from eps
- Support LossAwareSampler
- Support AdamW
James Betker 2021-06-05 13:40:32 -06:00
parent fa908a6a15
commit 80d4404367
3 changed files with 8 additions and 3 deletions


@@ -757,7 +757,6 @@ class GaussianDiffusion:
         terms = {}
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
-            x_start_pred = torch.zeros_like(x_start)  # This type of model doesn't predict x_start.
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
                 x_start=x_start,
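
For context: in the OpenAI guided-diffusion codebase this module appears to derive from, _vb_terms_bpd computes the per-timestep variational bound, the KL between the true posterior q(x_{t-1} | x_t, x_0) and the learned p(x_{t-1} | x_t) in bits per dimension, and it already produces an x_start estimate alongside the bound, which would make the zeros placeholder deleted above redundant. A hedged sketch of that upstream call shape (the keyword arguments and return keys are assumptions carried over from upstream, not confirmed against this fork):

    # Sketch, assuming the upstream guided-diffusion signature and return keys:
    vb = self._vb_terms_bpd(model=model, x_start=x_start, x_t=x_t, t=t,
                            clip_denoised=False, model_kwargs=model_kwargs)
    kl_bpd = vb["output"]             # KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) in bits/dim
    x_start_pred = vb["pred_xstart"]  # x_0 estimate produced while computing the bound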
@@ -803,7 +802,7 @@ class GaussianDiffusion:
                 x_start_pred = model_output
             elif self.model_mean_type == ModelMeanType.EPSILON:
                 target = noise
-                x_start_pred = x_t - model_output
+                x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
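
The old expression x_t - model_output was only a crude placeholder: with the standard forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, recovering x_0 from a predicted eps requires rescaling by the noise schedule. A minimal standalone sketch of what _predict_xstart_from_eps computes in the guided-diffusion lineage (the real method uses precomputed sqrt_recip_alphas_cumprod buffers; this version recomputes the terms inline):

    import torch

    def predict_xstart_from_eps(x_t, eps, alphas_cumprod, t):
        # Invert x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps for x_0:
        #   x_0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps
        a_bar = alphas_cumprod[t].view(-1, *([1] * (x_t.dim() - 1)))  # broadcast over batch dims
        return torch.sqrt(1.0 / a_bar) * x_t - torch.sqrt(1.0 / a_bar - 1.0) * eps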


@@ -1,7 +1,7 @@
 import torch
 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
-from models.diffusion.resample import create_named_schedule_sampler
+from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get
@@ -27,6 +27,8 @@ class GaussianDiffusionInjector(Injector):
         model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
         t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
         diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+        if isinstance(self.schedule_sampler, LossAwareSampler):
+            self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
         return {self.output: diffusion_outputs['mse'],
                 self.output_variational_bounds_key: diffusion_outputs['vb'],
                 self.output_x_start_key: diffusion_outputs['x_start_predicted']}
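
A LossAwareSampler maintains a running estimate of the loss at each timestep and importance-samples t accordingly, so training concentrates on the timesteps contributing the most loss; update_with_local_losses first gathers each rank's (t, loss) pairs so every distributed worker updates from the full global batch, which is why the injector must feed the per-sample losses back on every step. A usage sketch following the guided-diffusion API ('loss-second-moment' and the LossSecondMomentResampler behavior are assumed to be carried over from upstream):

    # 'loss-second-moment' builds a LossSecondMomentResampler, a LossAwareSampler
    # that samples t with probability proportional to sqrt(E[loss(t)^2]).
    sampler = create_named_schedule_sampler('loss-second-moment', diffusion)
    t, weights = sampler.sample(batch_size, device)
    losses = diffusion.training_losses(model, x_start, t)['loss']
    sampler.update_with_local_losses(t, losses.detach())  # all-gathers across ranks first
    loss = (losses * weights).mean()  # importance weights keep the estimate unbiased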


@@ -101,6 +101,10 @@ class ConfigurableStep(Module):
             opt = torch.optim.Adam(list(optim_params.values()),
                                    weight_decay=opt_config['weight_decay'],
                                    betas=(opt_config['beta1'], opt_config['beta2']))
+        elif self.step_opt['optimizer'] == 'adamw':
+            opt = torch.optim.AdamW(list(optim_params.values()),
+                                    weight_decay=opt_config['weight_decay'],
+                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum
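
Unlike torch.optim.Adam, where weight_decay is folded into the gradient as L2 regularization and therefore rescaled by the adaptive second-moment term, torch.optim.AdamW applies decoupled weight decay directly to the parameters (Loshchilov & Hutter, 2019). A minimal self-contained sketch (the hyperparameter values are illustrative, not taken from this repo's configs):

    import torch

    model = torch.nn.Linear(16, 16)
    # AdamW decays weights in-place each step: w <- w - lr * weight_decay * w,
    # independently of the Adam moment estimates.
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4,
                            weight_decay=1e-2, betas=(0.9, 0.999))
    opt.zero_grad()
    model(torch.randn(4, 16)).sum().backward()
    opt.step()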