From ef8d5f88c1ee43744011ec6409d9f5de6a55d064 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 27 Nov 2020 13:30:21 -0700 Subject: [PATCH] Bring split gaussian nll out of split so it can be computed accurately with the rest of the nll component --- codes/models/archs/srflow_orig/FlowUpsamplerNet.py | 4 +--- codes/models/archs/srflow_orig/SRFlowNet_arch.py | 6 +++++- codes/models/archs/srflow_orig/Split.py | 13 +++++++------ codes/train.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py index 4a4c8ce5..c208fbc4 100644 --- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py @@ -242,9 +242,7 @@ class FlowUpsamplerNet(nn.Module): def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None): ft = None if layer.position is None else rrdbResults[layer.position] fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot) - - if isinstance(epses, list): - epses.append(eps) + epses.append(eps) return fl_fea, logdet def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None): diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index 93a729ef..8973c69c 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -124,7 +124,11 @@ class SRFlowNet(nn.Module): else: z = epses - logp = flow.GaussianDiag.logp(None, None, z) + logp = 0 + for eps in epses: + logp = logp + flow.GaussianDiag.logp(None, None, eps) + logp_weight = opt_get(self.opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1) + logp = logp * logp_weight objective = objective + logp nll = (-objective) / float(np.log(2.) * pixels) diff --git a/codes/models/archs/srflow_orig/Split.py b/codes/models/archs/srflow_orig/Split.py index 1f5de9bf..f6d0aff7 100644 --- a/codes/models/archs/srflow_orig/Split.py +++ b/codes/models/archs/srflow_orig/Split.py @@ -18,7 +18,7 @@ class Split2d(nn.Module): out_channels=self.num_channels_consume * 2) self.logs_eps = logs_eps self.position = position - self.opt = opt + self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1) def split2d_prior(self, z, ft): if ft is not None: @@ -37,7 +37,8 @@ class Split2d(nn.Module): eps = (z2 - mean) / self.exp_eps(logs) - logdet = logdet + self.get_logdet(logs, mean, z2) + # This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses. + # logdet = logdet + self.get_logdet(logs, mean, z2) # print(logs.shape, mean.shape, z2.shape) # self.eps = eps @@ -54,17 +55,17 @@ class Split2d(nn.Module): eps = eps.to(mean.device) z2 = mean + self.exp_eps(logs) * eps - z = thops.cat_feature(z1, z2) - logdet = logdet - self.get_logdet(logs, mean, z2) + + # This has been moved into SRFlowNet_arch.py alongside the other Z NLL losses. + #logdet = logdet - self.get_logdet(logs, mean, z2) return z, logdet # return z, logdet, eps def get_logdet(self, logs, mean, z2): logdet_diff = GaussianDiag.logp(mean, logs, z2) - # print("Split2D: logdet diff", logdet_diff.item()) - return logdet_diff + return logdet_diff * self.gaussian_nll_weight def split_ratio(self, input): z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] diff --git a/codes/train.py b/codes/train.py index 251a079d..2f18ef64 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_rrdbdisc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()