forked from mrq/DL-Art-School
report quantile losses for diffusion
This commit is contained in:
parent
1394213f1e
commit
24e60bd510
|
@ -870,6 +870,7 @@ class GaussianDiffusion:
|
|||
s_err = (target - model_output) ** 2
|
||||
if channel_balancing_fn is not None:
|
||||
s_err = channel_balancing_fn(s_err)
|
||||
terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1)
|
||||
terms["mse"] = mean_flat(s_err)
|
||||
terms["x_start_predicted"] = x_start_pred
|
||||
if "vb" in terms:
|
||||
|
|
|
@ -56,12 +56,22 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.channel_balancing_fn = None
|
||||
assert k <= 1, 'Only one channel filtering function can be applied.'
|
||||
|
||||
self.num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
|
||||
self.latest_mse_by_batch = torch.tensor([0])
|
||||
self.latest_timesteps = torch.tensor([0])
|
||||
|
||||
def extra_metrics(self):
|
||||
uqt = self.latest_timesteps > self.num_timesteps * 3 / 4
|
||||
uql = (self.latest_mse_by_batch * uqt).sum() / uqt.sum()
|
||||
muqt = (self.latest_timesteps > self.num_timesteps / 2) * (self.latest_timesteps < self.num_timesteps * 3 / 4)
|
||||
muql = (self.latest_mse_by_batch * muqt).sum() / muqt.sum()
|
||||
d = {
|
||||
'upper_quantile_mse_loss': uql,
|
||||
'mid_upper_quantile_mse_loss': muql,
|
||||
}
|
||||
if hasattr(self, 'schedule_sampler') and isinstance(self.schedule_sampler, LossSecondMomentResampler):
|
||||
return {
|
||||
'sampler_warmed_up': torch.tensor(float(self.schedule_sampler._warmed_up()))
|
||||
}
|
||||
return {}
|
||||
d['sampler_warmed_up'] = torch.tensor(float(self.schedule_sampler._warmed_up()))
|
||||
return d
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
|
@ -87,6 +97,8 @@ class GaussianDiffusionInjector(Injector):
|
|||
out.update({self.output: diffusion_outputs['mse'],
|
||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||
self.latest_mse_by_batch = diffusion_outputs['mse_by_batch'].detach().clone()
|
||||
self.latest_timesteps = t.clone()
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user