From cad92bada89a9bd2a3a20ca05cb65a82121a2279 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 21 Nov 2020 10:13:05 -0700 Subject: [PATCH] Report logp and logdet for srflow --- codes/models/archs/srflow_orig/SRFlowNet_arch.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index 02ca4aac..f972a8f4 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -34,6 +34,8 @@ class SRFlowNet(nn.Module): self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step']) self.act_norm_always_init = False self.i = 0 + self.dbg_logp = 0 + self.dbg_logdet = 0 def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'): if seed: torch.manual_seed(seed) @@ -121,14 +123,20 @@ class SRFlowNet(nn.Module): else: z = epses - objective = objective + flow.GaussianDiag.logp(None, None, z) + logp = flow.GaussianDiag.logp(None, None, z) + objective = objective + logp nll = (-objective) / float(np.log(2.) * pixels) + self.dbg_logp = -logp.mean().item() / float(np.log(2.) * pixels) + self.dbg_logdet = -logdet.mean().item() / float(np.log(2.) * pixels) if isinstance(epses, list): return epses, nll, logdet return z, nll, logdet + def get_debug_values(self, s, n): + return {"logp": self.dbg_logp, "logdet": self.dbg_logdet} + def rrdbPreprocessing(self, lr): rrdbResults = self.RRDB(lr, get_steps=True) block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []