Bring split gaussian nll out of split so it can be computed accurately with the rest of the nll component

This commit is contained in:
James Betker 2020-11-27 13:30:21 -07:00
parent 11d2b70bdd
commit ef8d5f88c1
4 changed files with 14 additions and 11 deletions

View File

@ -242,9 +242,7 @@ class FlowUpsamplerNet(nn.Module):
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None): def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
ft = None if layer.position is None else rrdbResults[layer.position] ft = None if layer.position is None else rrdbResults[layer.position]
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot) fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
epses.append(eps)
if isinstance(epses, list):
epses.append(eps)
return fl_fea, logdet return fl_fea, logdet
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None): def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):

View File

@ -124,7 +124,11 @@ class SRFlowNet(nn.Module):
else: else:
z = epses z = epses
logp = flow.GaussianDiag.logp(None, None, z) logp = 0
for eps in epses:
logp = logp + flow.GaussianDiag.logp(None, None, eps)
logp_weight = opt_get(self.opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
logp = logp * logp_weight
objective = objective + logp objective = objective + logp
nll = (-objective) / float(np.log(2.) * pixels) nll = (-objective) / float(np.log(2.) * pixels)

View File

@ -18,7 +18,7 @@ class Split2d(nn.Module):
out_channels=self.num_channels_consume * 2) out_channels=self.num_channels_consume * 2)
self.logs_eps = logs_eps self.logs_eps = logs_eps
self.position = position self.position = position
self.opt = opt self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)
def split2d_prior(self, z, ft): def split2d_prior(self, z, ft):
if ft is not None: if ft is not None:
@ -37,7 +37,8 @@ class Split2d(nn.Module):
eps = (z2 - mean) / self.exp_eps(logs) eps = (z2 - mean) / self.exp_eps(logs)
logdet = logdet + self.get_logdet(logs, mean, z2) # This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
# logdet = logdet + self.get_logdet(logs, mean, z2)
# print(logs.shape, mean.shape, z2.shape) # print(logs.shape, mean.shape, z2.shape)
# self.eps = eps # self.eps = eps
@ -54,17 +55,17 @@ class Split2d(nn.Module):
eps = eps.to(mean.device) eps = eps.to(mean.device)
z2 = mean + self.exp_eps(logs) * eps z2 = mean + self.exp_eps(logs) * eps
z = thops.cat_feature(z1, z2) z = thops.cat_feature(z1, z2)
logdet = logdet - self.get_logdet(logs, mean, z2)
# This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses.
#logdet = logdet - self.get_logdet(logs, mean, z2)
return z, logdet return z, logdet
# return z, logdet, eps # return z, logdet, eps
def get_logdet(self, logs, mean, z2): def get_logdet(self, logs, mean, z2):
logdet_diff = GaussianDiag.logp(mean, logs, z2) logdet_diff = GaussianDiag.logp(mean, logs, z2)
# print("Split2D: logdet diff", logdet_diff.item()) return logdet_diff * self.gaussian_nll_weight
return logdet_diff
def split_ratio(self, input): def split_ratio(self, input):
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_rrdbdisc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()