Report logp and logdet for srflow
This commit is contained in:
parent
c37d3faa58
commit
cad92bada8
|
@ -34,6 +34,8 @@ class SRFlowNet(nn.Module):
|
||||||
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
|
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
|
||||||
self.act_norm_always_init = False
|
self.act_norm_always_init = False
|
||||||
self.i = 0
|
self.i = 0
|
||||||
|
self.dbg_logp = 0
|
||||||
|
self.dbg_logdet = 0
|
||||||
|
|
||||||
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
|
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
|
||||||
if seed: torch.manual_seed(seed)
|
if seed: torch.manual_seed(seed)
|
||||||
|
@ -121,14 +123,20 @@ class SRFlowNet(nn.Module):
|
||||||
else:
|
else:
|
||||||
z = epses
|
z = epses
|
||||||
|
|
||||||
objective = objective + flow.GaussianDiag.logp(None, None, z)
|
logp = flow.GaussianDiag.logp(None, None, z)
|
||||||
|
objective = objective + logp
|
||||||
|
|
||||||
nll = (-objective) / float(np.log(2.) * pixels)
|
nll = (-objective) / float(np.log(2.) * pixels)
|
||||||
|
self.dbg_logp = -logp.mean().item() / float(np.log(2.) * pixels)
|
||||||
|
self.dbg_logdet = -logdet.mean().item() / float(np.log(2.) * pixels)
|
||||||
|
|
||||||
if isinstance(epses, list):
|
if isinstance(epses, list):
|
||||||
return epses, nll, logdet
|
return epses, nll, logdet
|
||||||
return z, nll, logdet
|
return z, nll, logdet
|
||||||
|
|
||||||
|
def get_debug_values(self, s, n):
|
||||||
|
return {"logp": self.dbg_logp, "logdet": self.dbg_logdet}
|
||||||
|
|
||||||
def rrdbPreprocessing(self, lr):
|
def rrdbPreprocessing(self, lr):
|
||||||
rrdbResults = self.RRDB(lr, get_steps=True)
|
rrdbResults = self.RRDB(lr, get_steps=True)
|
||||||
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user