report quantile losses for diffusion

This commit is contained in:
James Betker 2022-06-21 20:04:16 -06:00
parent 1394213f1e
commit 24e60bd510
2 changed files with 17 additions and 4 deletions

View File

@ -870,6 +870,7 @@ class GaussianDiffusion:
s_err = (target - model_output) ** 2
if channel_balancing_fn is not None:
s_err = channel_balancing_fn(s_err)
terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1)
terms["mse"] = mean_flat(s_err)
terms["x_start_predicted"] = x_start_pred
if "vb" in terms:

View File

@ -56,12 +56,22 @@ class GaussianDiffusionInjector(Injector):
self.channel_balancing_fn = None
assert k <= 1, 'Only one channel filtering function can be applied.'
self.num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
self.latest_mse_by_batch = torch.tensor([0])
self.latest_timesteps = torch.tensor([0])
def extra_metrics(self):
    """Return the mean per-sample MSE of the latest batch for the upper timestep ranges.

    Reads ``self.latest_timesteps`` and ``self.latest_mse_by_batch`` (element
    ``i`` of each is combined by an element-wise mask, so they are expected to
    line up one-to-one) together with ``self.num_timesteps``.

    Returns:
        A dict with:
          * ``'upper_quantile_mse_loss'``: mean loss over samples whose timestep
            is greater than 3/4 of ``num_timesteps``.
          * ``'mid_upper_quantile_mse_loss'``: mean loss over samples whose
            timestep is greater than 1/2 and at most 3/4 of ``num_timesteps``.
          * ``'sampler_warmed_up'``: only present when ``self.schedule_sampler``
            exists and is a ``LossSecondMomentResampler``; 1.0 or 0.0 as a tensor.

        A range that contains no samples reports 0 rather than NaN.
    """
    def masked_mean(values, mask):
        # Clamping the count makes an empty range evaluate to 0 / 1 == 0
        # instead of 0 / 0 == NaN, without a data-dependent Python branch.
        return (values * mask).sum() / mask.sum().clamp(min=1)

    timesteps = self.latest_timesteps
    losses = self.latest_mse_by_batch

    upper = timesteps > self.num_timesteps * 3 / 4
    # Everything above the midpoint that is not already in the upper range, so
    # a timestep sitting exactly on the 3/4 boundary is still counted once.
    mid_upper = (timesteps > self.num_timesteps / 2) & ~upper

    metrics = {
        'upper_quantile_mse_loss': masked_mean(losses, upper),
        'mid_upper_quantile_mse_loss': masked_mean(losses, mid_upper),
    }
    if hasattr(self, 'schedule_sampler') and isinstance(self.schedule_sampler, LossSecondMomentResampler):
        metrics['sampler_warmed_up'] = torch.tensor(float(self.schedule_sampler._warmed_up()))
    return metrics
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
@ -87,6 +97,8 @@ class GaussianDiffusionInjector(Injector):
out.update({self.output: diffusion_outputs['mse'],
self.output_variational_bounds_key: diffusion_outputs['vb'],
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
self.latest_mse_by_batch = diffusion_outputs['mse_by_batch'].detach().clone()
self.latest_timesteps = t.clone()
return out