More adjustments to srflow_orig

This commit is contained in:
James Betker 2020-11-20 19:38:33 -07:00
parent d51d12a41a
commit c37d3faa58
2 changed files with 21 additions and 1 deletions

View File

@ -22,6 +22,7 @@ class _ActNorm(nn.Module):
self.num_features = num_features
self.scale = float(scale)
self.inited = False
self.force_initialization = False
def _check_input_dim(self, input):
return NotImplemented
@ -76,7 +77,7 @@ class _ActNorm(nn.Module):
return input, logdet
def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
if not self.inited:
if self.force_initialization or not self.inited:
self.initialize_parameters(input)
self._check_input_dim(input)

View File

@ -31,6 +31,8 @@ class SRFlowNet(nn.Module):
self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
self.act_norm_always_init = False
self.i = 0
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
@ -52,6 +54,23 @@ class SRFlowNet(nn.Module):
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
return z.to(device)
def update_for_step(self, step, experiments_path='.'):
if self.act_norm_always_init and step > self.force_act_norm_init_until:
set_act_norm_always_init = True
set_value = False
self.act_norm_always_init = False
elif not self.act_norm_always_init and step < self.force_act_norm_init_until:
set_act_norm_always_init = True
set_value = True
self.act_norm_always_init = True
else:
set_act_norm_always_init = False
if set_act_norm_always_init:
for m in self.modules():
from models.archs.srflow_orig.FlowActNorms import _ActNorm
if isinstance(m, _ActNorm):
m.force_initialization = set_value
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
lr_enc=None,
add_gt_noise=True, step=None, y_label=None):