Bring split gaussian nll out of split so it can be computed accurately with the rest of the nll component

This commit is contained in:
James Betker 2020-11-27 13:30:21 -07:00
parent 11d2b70bdd
commit ef8d5f88c1
4 changed files with 14 additions and 11 deletions

View File

@ -242,8 +242,6 @@ class FlowUpsamplerNet(nn.Module):
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
ft = None if layer.position is None else rrdbResults[layer.position]
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
if isinstance(epses, list):
epses.append(eps)
return fl_fea, logdet

View File

@ -124,7 +124,11 @@ class SRFlowNet(nn.Module):
else:
z = epses
logp = flow.GaussianDiag.logp(None, None, z)
logp = 0
for eps in epses:
logp = logp + flow.GaussianDiag.logp(None, None, eps)
logp_weight = opt_get(self.opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
logp = logp * logp_weight
objective = objective + logp
nll = (-objective) / float(np.log(2.) * pixels)

View File

@ -18,7 +18,7 @@ class Split2d(nn.Module):
out_channels=self.num_channels_consume * 2)
self.logs_eps = logs_eps
self.position = position
self.opt = opt
self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
def split2d_prior(self, z, ft):
if ft is not None:
@ -37,7 +37,8 @@ class Split2d(nn.Module):
eps = (z2 - mean) / self.exp_eps(logs)
logdet = logdet + self.get_logdet(logs, mean, z2)
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
# logdet = logdet + self.get_logdet(logs, mean, z2)
# print(logs.shape, mean.shape, z2.shape)
# self.eps = eps
@ -54,17 +55,17 @@ class Split2d(nn.Module):
eps = eps.to(mean.device)
z2 = mean + self.exp_eps(logs) * eps
z = thops.cat_feature(z1, z2)
logdet = logdet - self.get_logdet(logs, mean, z2)
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
#logdet = logdet - self.get_logdet(logs, mean, z2)
return z, logdet
# return z, logdet, eps
def get_logdet(self, logs, mean, z2):
logdet_diff = GaussianDiag.logp(mean, logs, z2)
# print("Split2D: logdet diff", logdet_diff.item())
return logdet_diff
return logdet_diff * self.gaussian_nll_weight
def split_ratio(self, input):
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_rrdbdisc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()