Report logp and logdet for srflow

This commit is contained in:
James Betker 2020-11-21 10:13:05 -07:00
parent c37d3faa58
commit cad92bada8

View File

@ -34,6 +34,8 @@ class SRFlowNet(nn.Module):
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step']) self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
self.act_norm_always_init = False self.act_norm_always_init = False
self.i = 0 self.i = 0
self.dbg_logp = 0
self.dbg_logdet = 0
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'): def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
if seed: torch.manual_seed(seed) if seed: torch.manual_seed(seed)
@ -121,14 +123,20 @@ class SRFlowNet(nn.Module):
else: else:
z = epses z = epses
objective = objective + flow.GaussianDiag.logp(None, None, z) logp = flow.GaussianDiag.logp(None, None, z)
objective = objective + logp
nll = (-objective) / float(np.log(2.) * pixels) nll = (-objective) / float(np.log(2.) * pixels)
self.dbg_logp = -logp.mean().item() / float(np.log(2.) * pixels)
self.dbg_logdet = -logdet.mean().item() / float(np.log(2.) * pixels)
if isinstance(epses, list): if isinstance(epses, list):
return epses, nll, logdet return epses, nll, logdet
return z, nll, logdet return z, nll, logdet
def get_debug_values(self, s, n):
return {"logp": self.dbg_logp, "logdet": self.dbg_logdet}
def rrdbPreprocessing(self, lr): def rrdbPreprocessing(self, lr):
rrdbResults = self.RRDB(lr, get_steps=True) rrdbResults = self.RRDB(lr, get_steps=True)
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or [] block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []