A few fixes:
- Output better prediction of xstart from eps - Support LossAwareSampler - Support AdamW
This commit is contained in:
parent
fa908a6a15
commit
80d4404367
|
@ -757,7 +757,6 @@ class GaussianDiffusion:
|
|||
terms = {}
|
||||
|
||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||
x_start_pred = torch.zeros_like(x_start) # This type of model doesn't predict x_start.
|
||||
terms["loss"] = self._vb_terms_bpd(
|
||||
model=model,
|
||||
x_start=x_start,
|
||||
|
@ -803,7 +802,7 @@ class GaussianDiffusion:
|
|||
x_start_pred = model_output
|
||||
elif self.model_mean_type == ModelMeanType.EPSILON:
|
||||
target = noise
|
||||
x_start_pred = x_t - model_output
|
||||
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
||||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
||||
from models.diffusion.resample import create_named_schedule_sampler
|
||||
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
|
||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||
from trainer.inject import Injector
|
||||
from utils.util import opt_get
|
||||
|
@ -27,6 +27,8 @@ class GaussianDiffusionInjector(Injector):
|
|||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||
return {self.output: diffusion_outputs['mse'],
|
||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
|
||||
|
|
|
@ -101,6 +101,10 @@ class ConfigurableStep(Module):
|
|||
opt = torch.optim.Adam(list(optim_params.values()),
|
||||
weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'adamw':
|
||||
opt = torch.optim.AdamW(list(optim_params.values()),
|
||||
weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'lars':
|
||||
from trainer.optimizers.larc import LARC
|
||||
from trainer.optimizers.sgd import SGDNoBiasMomentum
|
||||
|
|
Loading…
Reference in New Issue
Block a user