Bring split gaussian nll out of split so it can be computed accurately with the rest of the nll component
This commit is contained in:
parent
11d2b70bdd
commit
ef8d5f88c1
|
@ -242,9 +242,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
|
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
|
||||||
ft = None if layer.position is None else rrdbResults[layer.position]
|
ft = None if layer.position is None else rrdbResults[layer.position]
|
||||||
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
|
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
|
||||||
|
epses.append(eps)
|
||||||
if isinstance(epses, list):
|
|
||||||
epses.append(eps)
|
|
||||||
return fl_fea, logdet
|
return fl_fea, logdet
|
||||||
|
|
||||||
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
|
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
|
||||||
|
|
|
@ -124,7 +124,11 @@ class SRFlowNet(nn.Module):
|
||||||
else:
|
else:
|
||||||
z = epses
|
z = epses
|
||||||
|
|
||||||
logp = flow.GaussianDiag.logp(None, None, z)
|
logp = 0
|
||||||
|
for eps in epses:
|
||||||
|
logp = logp + flow.GaussianDiag.logp(None, None, eps)
|
||||||
|
logp_weight = opt_get(self.opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
|
||||||
|
logp = logp * logp_weight
|
||||||
objective = objective + logp
|
objective = objective + logp
|
||||||
|
|
||||||
nll = (-objective) / float(np.log(2.) * pixels)
|
nll = (-objective) / float(np.log(2.) * pixels)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class Split2d(nn.Module):
|
||||||
out_channels=self.num_channels_consume * 2)
|
out_channels=self.num_channels_consume * 2)
|
||||||
self.logs_eps = logs_eps
|
self.logs_eps = logs_eps
|
||||||
self.position = position
|
self.position = position
|
||||||
self.opt = opt
|
self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
|
||||||
|
|
||||||
def split2d_prior(self, z, ft):
|
def split2d_prior(self, z, ft):
|
||||||
if ft is not None:
|
if ft is not None:
|
||||||
|
@ -37,7 +37,8 @@ class Split2d(nn.Module):
|
||||||
|
|
||||||
eps = (z2 - mean) / self.exp_eps(logs)
|
eps = (z2 - mean) / self.exp_eps(logs)
|
||||||
|
|
||||||
logdet = logdet + self.get_logdet(logs, mean, z2)
|
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
|
||||||
|
# logdet = logdet + self.get_logdet(logs, mean, z2)
|
||||||
|
|
||||||
# print(logs.shape, mean.shape, z2.shape)
|
# print(logs.shape, mean.shape, z2.shape)
|
||||||
# self.eps = eps
|
# self.eps = eps
|
||||||
|
@ -54,17 +55,17 @@ class Split2d(nn.Module):
|
||||||
|
|
||||||
eps = eps.to(mean.device)
|
eps = eps.to(mean.device)
|
||||||
z2 = mean + self.exp_eps(logs) * eps
|
z2 = mean + self.exp_eps(logs) * eps
|
||||||
|
|
||||||
z = thops.cat_feature(z1, z2)
|
z = thops.cat_feature(z1, z2)
|
||||||
logdet = logdet - self.get_logdet(logs, mean, z2)
|
|
||||||
|
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
|
||||||
|
#logdet = logdet - self.get_logdet(logs, mean, z2)
|
||||||
|
|
||||||
return z, logdet
|
return z, logdet
|
||||||
# return z, logdet, eps
|
# return z, logdet, eps
|
||||||
|
|
||||||
def get_logdet(self, logs, mean, z2):
|
def get_logdet(self, logs, mean, z2):
|
||||||
logdet_diff = GaussianDiag.logp(mean, logs, z2)
|
logdet_diff = GaussianDiag.logp(mean, logs, z2)
|
||||||
# print("Split2D: logdet diff", logdet_diff.item())
|
return logdet_diff * self.gaussian_nll_weight
|
||||||
return logdet_diff
|
|
||||||
|
|
||||||
def split_ratio(self, input):
|
def split_ratio(self, input):
|
||||||
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
|
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_rrdbdisc.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user