forked from mrq/DL-Art-School
report quantile losses for diffusion
This commit is contained in:
parent
1394213f1e
commit
24e60bd510
|
@ -870,6 +870,7 @@ class GaussianDiffusion:
|
||||||
s_err = (target - model_output) ** 2
|
s_err = (target - model_output) ** 2
|
||||||
if channel_balancing_fn is not None:
|
if channel_balancing_fn is not None:
|
||||||
s_err = channel_balancing_fn(s_err)
|
s_err = channel_balancing_fn(s_err)
|
||||||
|
terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1)
|
||||||
terms["mse"] = mean_flat(s_err)
|
terms["mse"] = mean_flat(s_err)
|
||||||
terms["x_start_predicted"] = x_start_pred
|
terms["x_start_predicted"] = x_start_pred
|
||||||
if "vb" in terms:
|
if "vb" in terms:
|
||||||
|
|
|
@ -56,12 +56,22 @@ class GaussianDiffusionInjector(Injector):
|
||||||
self.channel_balancing_fn = None
|
self.channel_balancing_fn = None
|
||||||
assert k <= 1, 'Only one channel filtering function can be applied.'
|
assert k <= 1, 'Only one channel filtering function can be applied.'
|
||||||
|
|
||||||
|
self.num_timesteps = opt['beta_schedule']['num_diffusion_timesteps']
|
||||||
|
self.latest_mse_by_batch = torch.tensor([0])
|
||||||
|
self.latest_timesteps = torch.tensor([0])
|
||||||
|
|
||||||
def extra_metrics(self):
    """Report the most recent per-sample MSE averaged over two timestep bands.

    Reads ``self.latest_mse_by_batch`` and ``self.latest_timesteps`` (one
    entry per batch element) together with ``self.num_timesteps``.

    Returns:
        dict mapping metric name to a scalar tensor:

        * ``'upper_quantile_mse_loss'``: mean MSE of samples with
          ``t > 3/4 * num_timesteps``.
        * ``'mid_upper_quantile_mse_loss'``: mean MSE of samples with
          ``1/2 * num_timesteps < t <= 3/4 * num_timesteps``.
        * ``'sampler_warmed_up'``: only present when ``self.schedule_sampler``
          exists and is a ``LossSecondMomentResampler``.

        A band that contains no samples reports 0 rather than NaN.
    """
    timesteps = self.latest_timesteps
    mse = self.latest_mse_by_batch
    half = self.num_timesteps / 2
    three_quarters = self.num_timesteps * 3 / 4

    def band_mean(mask):
        # Clamp the sample count to at least 1: an empty band (always the case
        # for the initial placeholder tensors, and common with small batches)
        # would otherwise evaluate 0 / 0 and log NaN.
        return (mse * mask).sum() / mask.sum().clamp(min=1)

    upper_mask = timesteps > three_quarters
    # The upper edge is inclusive so that t == 3/4 * num_timesteps is counted
    # in exactly one band instead of being dropped from both.
    mid_upper_mask = (timesteps > half) & (timesteps <= three_quarters)

    metrics = {
        'upper_quantile_mse_loss': band_mean(upper_mask),
        'mid_upper_quantile_mse_loss': band_mean(mid_upper_mask),
    }
    if hasattr(self, 'schedule_sampler') and isinstance(self.schedule_sampler, LossSecondMomentResampler):
        metrics['sampler_warmed_up'] = torch.tensor(float(self.schedule_sampler._warmed_up()))
    return metrics
|
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
|
@ -87,6 +97,8 @@ class GaussianDiffusionInjector(Injector):
|
||||||
out.update({self.output: diffusion_outputs['mse'],
|
out.update({self.output: diffusion_outputs['mse'],
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||||
|
self.latest_mse_by_batch = diffusion_outputs['mse_by_batch'].detach().clone()
|
||||||
|
self.latest_timesteps = t.clone()
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user