Bring split gaussian nll out of split so it can be computed accurately with the rest of the nll component
This commit is contained in:
parent
11d2b70bdd
commit
ef8d5f88c1
|
@ -242,8 +242,6 @@ class FlowUpsamplerNet(nn.Module):
|
|||
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
|
||||
ft = None if layer.position is None else rrdbResults[layer.position]
|
||||
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
|
||||
|
||||
if isinstance(epses, list):
|
||||
epses.append(eps)
|
||||
return fl_fea, logdet
|
||||
|
||||
|
|
|
@ -124,7 +124,11 @@ class SRFlowNet(nn.Module):
|
|||
else:
|
||||
z = epses
|
||||
|
||||
logp = flow.GaussianDiag.logp(None, None, z)
|
||||
logp = 0
|
||||
for eps in epses:
|
||||
logp = logp + flow.GaussianDiag.logp(None, None, eps)
|
||||
logp_weight = opt_get(self.opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
|
||||
logp = logp * logp_weight
|
||||
objective = objective + logp
|
||||
|
||||
nll = (-objective) / float(np.log(2.) * pixels)
|
||||
|
|
|
@ -18,7 +18,7 @@ class Split2d(nn.Module):
|
|||
out_channels=self.num_channels_consume * 2)
|
||||
self.logs_eps = logs_eps
|
||||
self.position = position
|
||||
self.opt = opt
|
||||
self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
|
||||
|
||||
def split2d_prior(self, z, ft):
|
||||
if ft is not None:
|
||||
|
@ -37,7 +37,8 @@ class Split2d(nn.Module):
|
|||
|
||||
eps = (z2 - mean) / self.exp_eps(logs)
|
||||
|
||||
logdet = logdet + self.get_logdet(logs, mean, z2)
|
||||
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
|
||||
# logdet = logdet + self.get_logdet(logs, mean, z2)
|
||||
|
||||
# print(logs.shape, mean.shape, z2.shape)
|
||||
# self.eps = eps
|
||||
|
@ -54,17 +55,17 @@ class Split2d(nn.Module):
|
|||
|
||||
eps = eps.to(mean.device)
|
||||
z2 = mean + self.exp_eps(logs) * eps
|
||||
|
||||
z = thops.cat_feature(z1, z2)
|
||||
logdet = logdet - self.get_logdet(logs, mean, z2)
|
||||
|
||||
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
|
||||
#logdet = logdet - self.get_logdet(logs, mean, z2)
|
||||
|
||||
return z, logdet
|
||||
# return z, logdet, eps
|
||||
|
||||
def get_logdet(self, logs, mean, z2):
|
||||
logdet_diff = GaussianDiag.logp(mean, logs, z2)
|
||||
# print("Split2D: logdet diff", logdet_diff.item())
|
||||
return logdet_diff
|
||||
return logdet_diff * self.gaussian_nll_weight
|
||||
|
||||
def split_ratio(self, input):
|
||||
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_rrdbdisc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user