A few fixes:
- Output better prediction of xstart from eps - Support LossAwareSampler - Support AdamW
This commit is contained in:
parent
fa908a6a15
commit
80d4404367
|
@ -757,7 +757,6 @@ class GaussianDiffusion:
|
||||||
terms = {}
|
terms = {}
|
||||||
|
|
||||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||||
x_start_pred = torch.zeros_like(x_start) # This type of model doesn't predict x_start.
|
|
||||||
terms["loss"] = self._vb_terms_bpd(
|
terms["loss"] = self._vb_terms_bpd(
|
||||||
model=model,
|
model=model,
|
||||||
x_start=x_start,
|
x_start=x_start,
|
||||||
|
@ -803,7 +802,7 @@ class GaussianDiffusion:
|
||||||
x_start_pred = model_output
|
x_start_pred = model_output
|
||||||
elif self.model_mean_type == ModelMeanType.EPSILON:
|
elif self.model_mean_type == ModelMeanType.EPSILON:
|
||||||
target = noise
|
target = noise
|
||||||
x_start_pred = x_t - model_output
|
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self.model_mean_type)
|
raise NotImplementedError(self.model_mean_type)
|
||||||
assert model_output.shape == target.shape == x_start.shape
|
assert model_output.shape == target.shape == x_start.shape
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
|
||||||
from models.diffusion.resample import create_named_schedule_sampler
|
from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
|
||||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||||
from trainer.inject import Injector
|
from trainer.inject import Injector
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
@ -27,6 +27,8 @@ class GaussianDiffusionInjector(Injector):
|
||||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||||
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
|
||||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||||
|
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||||
|
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||||
return {self.output: diffusion_outputs['mse'],
|
return {self.output: diffusion_outputs['mse'],
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
|
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
|
||||||
|
|
|
@ -101,6 +101,10 @@ class ConfigurableStep(Module):
|
||||||
opt = torch.optim.Adam(list(optim_params.values()),
|
opt = torch.optim.Adam(list(optim_params.values()),
|
||||||
weight_decay=opt_config['weight_decay'],
|
weight_decay=opt_config['weight_decay'],
|
||||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||||
|
elif self.step_opt['optimizer'] == 'adamw':
|
||||||
|
opt = torch.optim.AdamW(list(optim_params.values()),
|
||||||
|
weight_decay=opt_config['weight_decay'],
|
||||||
|
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||||
elif self.step_opt['optimizer'] == 'lars':
|
elif self.step_opt['optimizer'] == 'lars':
|
||||||
from trainer.optimizers.larc import LARC
|
from trainer.optimizers.larc import LARC
|
||||||
from trainer.optimizers.sgd import SGDNoBiasMomentum
|
from trainer.optimizers.sgd import SGDNoBiasMomentum
|
||||||
|
|
Loading…
Reference in New Issue
Block a user