Report logp and logdet for srflow
This commit is contained in:
parent
c37d3faa58
commit
cad92bada8
|
@ -34,6 +34,8 @@ class SRFlowNet(nn.Module):
|
|||
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
|
||||
self.act_norm_always_init = False
|
||||
self.i = 0
|
||||
self.dbg_logp = 0
|
||||
self.dbg_logdet = 0
|
||||
|
||||
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
|
||||
if seed: torch.manual_seed(seed)
|
||||
|
@ -121,14 +123,20 @@ class SRFlowNet(nn.Module):
|
|||
else:
|
||||
z = epses
|
||||
|
||||
objective = objective + flow.GaussianDiag.logp(None, None, z)
|
||||
logp = flow.GaussianDiag.logp(None, None, z)
|
||||
objective = objective + logp
|
||||
|
||||
nll = (-objective) / float(np.log(2.) * pixels)
|
||||
self.dbg_logp = -logp.mean().item() / float(np.log(2.) * pixels)
|
||||
self.dbg_logdet = -logdet.mean().item() / float(np.log(2.) * pixels)
|
||||
|
||||
if isinstance(epses, list):
|
||||
return epses, nll, logdet
|
||||
return z, nll, logdet
|
||||
|
||||
def get_debug_values(self, s, n):
|
||||
return {"logp": self.dbg_logp, "logdet": self.dbg_logdet}
|
||||
|
||||
def rrdbPreprocessing(self, lr):
|
||||
rrdbResults = self.RRDB(lr, get_steps=True)
|
||||
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||
|
|
Loading…
Reference in New Issue
Block a user