From c25b49bb12ccd299ca2c7907685f888cb12d2f22 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 16 Dec 2020 10:27:38 -0700 Subject: [PATCH] Clean up of SRFlowNet_arch --- .../models/archs/srflow_orig/FlowActNorms.py | 3 --- .../archs/srflow_orig/SRFlowNet_arch.py | 21 +------------------ 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/codes/models/archs/srflow_orig/FlowActNorms.py b/codes/models/archs/srflow_orig/FlowActNorms.py index d1e993ab..4314607d 100644 --- a/codes/models/archs/srflow_orig/FlowActNorms.py +++ b/codes/models/archs/srflow_orig/FlowActNorms.py @@ -22,7 +22,6 @@ class _ActNorm(nn.Module): self.num_features = num_features self.scale = float(scale) self.inited = False - self.force_initialization = False def _check_input_dim(self, input): return NotImplemented @@ -77,8 +76,6 @@ class _ActNorm(nn.Module): return input, logdet def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): - if self.force_initialization or not self.inited: - self.initialize_parameters(input) self._check_input_dim(input) if offset_mask is not None: diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index d1d9cf62..b5959f2b 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -34,8 +34,6 @@ class SRFlowNet(nn.Module): self.flowUpsamplerNet = \ FlowUpsamplerNet((self.patch_sz, self.patch_sz, 3), hidden_channels, K, flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt) - self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step']) - self.act_norm_always_init = False self.i = 0 self.dbg_logp = 0 self.dbg_logdet = 0 @@ -59,23 +57,6 @@ class SRFlowNet(nn.Module): z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) return z.to(device) - def update_for_step(self, step, experiments_path='.'): - if self.act_norm_always_init and step > self.force_act_norm_init_until: - set_act_norm_always_init = True - set_value = False - self.act_norm_always_init = False - elif not self.act_norm_always_init and step < self.force_act_norm_init_until: - set_act_norm_always_init = True - set_value = True - self.act_norm_always_init = True - else: - set_act_norm_always_init = False - if set_act_norm_always_init: - for m in self.modules(): - from models.archs.srflow_orig.FlowActNorms import _ActNorm - if isinstance(m, _ActNorm): - m.force_initialization = set_value - def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, lr_enc=None, add_gt_noise=True, step=None, y_label=None): @@ -170,7 +151,7 @@ class SRFlowNet(nn.Module): def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): logdet = torch.zeros_like(lr[:, 0, 0, 0]) - pixels = thops.pixels(lr) * self.flow_scale ** 2 + pixels = thops.pixels(lr) * self.opt['scale'] ** 2 if add_gt_noise: logdet = logdet - float(-np.log(self.quant) * pixels)