diff --git a/codes/models/archs/srflow_orig/FlowActNorms.py b/codes/models/archs/srflow_orig/FlowActNorms.py index e92dc642..d1e993ab 100644 --- a/codes/models/archs/srflow_orig/FlowActNorms.py +++ b/codes/models/archs/srflow_orig/FlowActNorms.py @@ -22,6 +22,7 @@ class _ActNorm(nn.Module): self.num_features = num_features self.scale = float(scale) self.inited = False + self.force_initialization = False def _check_input_dim(self, input): return NotImplemented @@ -76,7 +77,7 @@ class _ActNorm(nn.Module): return input, logdet def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): - if not self.inited: + if self.force_initialization or not self.inited: self.initialize_parameters(input) self._check_input_dim(input) diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index 5cb9d4b4..02ca4aac 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -31,6 +31,8 @@ class SRFlowNet(nn.Module): self.flowUpsamplerNet = \ FlowUpsamplerNet((160, 160, 3), hidden_channels, K, flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt) + self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step']) + self.act_norm_always_init = False self.i = 0 def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'): @@ -52,6 +54,23 @@ class SRFlowNet(nn.Module): z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) return z.to(device) + def update_for_step(self, step, experiments_path='.'): + if self.act_norm_always_init and step > self.force_act_norm_init_until: + set_act_norm_always_init = True + set_value = False + self.act_norm_always_init = False + elif not self.act_norm_always_init and step < self.force_act_norm_init_until: + set_act_norm_always_init = True + set_value = True + self.act_norm_always_init = True + else: + set_act_norm_always_init = False + if set_act_norm_always_init: + for m in self.modules(): + from models.archs.srflow_orig.FlowActNorms import _ActNorm + if isinstance(m, _ActNorm): + m.force_initialization = set_value + def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, lr_enc=None, add_gt_noise=True, step=None, y_label=None):