From 24e60bd510dac3a7eac2f3c948afa65b4f0c738e Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 21 Jun 2022 20:04:16 -0600 Subject: [PATCH] report quantile losses for diffusion --- codes/models/diffusion/gaussian_diffusion.py | 1 + .../injectors/gaussian_diffusion_injector.py | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 77751a72..472b210b 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -870,6 +870,7 @@ class GaussianDiffusion: s_err = (target - model_output) ** 2 if channel_balancing_fn is not None: s_err = channel_balancing_fn(s_err) + terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1) terms["mse"] = mean_flat(s_err) terms["x_start_predicted"] = x_start_pred if "vb" in terms: diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index d20678bc..7f82a821 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -56,12 +56,22 @@ class GaussianDiffusionInjector(Injector): self.channel_balancing_fn = None assert k <= 1, 'Only one channel filtering function can be applied.' + self.num_timesteps = opt['beta_schedule']['num_diffusion_timesteps'] + self.latest_mse_by_batch = torch.tensor([0]) + self.latest_timesteps = torch.tensor([0]) + def extra_metrics(self): + uqt = self.latest_timesteps > self.num_timesteps * 3 / 4 + uql = (self.latest_mse_by_batch * uqt).sum() / uqt.sum() + muqt = (self.latest_timesteps > self.num_timesteps / 2) * (self.latest_timesteps < self.num_timesteps * 3 / 4) + muql = (self.latest_mse_by_batch * muqt).sum() / muqt.sum() + d = { + 'upper_quantile_mse_loss': uql, + 'mid_upper_quantile_mse_loss': muql, + } if hasattr(self, 'schedule_sampler') and isinstance(self.schedule_sampler, LossSecondMomentResampler): - return { - 'sampler_warmed_up': torch.tensor(float(self.schedule_sampler._warmed_up())) - } - return {} + d['sampler_warmed_up'] = torch.tensor(float(self.schedule_sampler._warmed_up())) + return d def forward(self, state): gen = self.env['generators'][self.opt['generator']] @@ -87,6 +97,8 @@ class GaussianDiffusionInjector(Injector): out.update({self.output: diffusion_outputs['mse'], self.output_variational_bounds_key: diffusion_outputs['vb'], self.output_x_start_key: diffusion_outputs['x_start_predicted']}) + self.latest_mse_by_batch = diffusion_outputs['mse_by_batch'].detach().clone() + self.latest_timesteps = t.clone() return out