forked from mrq/DL-Art-School
Clean up of SRFlowNet_arch
This commit is contained in:
parent
42ac8e3eeb
commit
c25b49bb12
|
@ -22,7 +22,6 @@ class _ActNorm(nn.Module):
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.inited = False
|
self.inited = False
|
||||||
self.force_initialization = False
|
|
||||||
|
|
||||||
def _check_input_dim(self, input):
|
def _check_input_dim(self, input):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -77,8 +76,6 @@ class _ActNorm(nn.Module):
|
||||||
return input, logdet
|
return input, logdet
|
||||||
|
|
||||||
def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
|
def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
|
||||||
if self.force_initialization or not self.inited:
|
|
||||||
self.initialize_parameters(input)
|
|
||||||
self._check_input_dim(input)
|
self._check_input_dim(input)
|
||||||
|
|
||||||
if offset_mask is not None:
|
if offset_mask is not None:
|
||||||
|
|
|
@ -34,8 +34,6 @@ class SRFlowNet(nn.Module):
|
||||||
self.flowUpsamplerNet = \
|
self.flowUpsamplerNet = \
|
||||||
FlowUpsamplerNet((self.patch_sz, self.patch_sz, 3), hidden_channels, K,
|
FlowUpsamplerNet((self.patch_sz, self.patch_sz, 3), hidden_channels, K,
|
||||||
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
|
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
|
||||||
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
|
|
||||||
self.act_norm_always_init = False
|
|
||||||
self.i = 0
|
self.i = 0
|
||||||
self.dbg_logp = 0
|
self.dbg_logp = 0
|
||||||
self.dbg_logdet = 0
|
self.dbg_logdet = 0
|
||||||
|
@ -59,23 +57,6 @@ class SRFlowNet(nn.Module):
|
||||||
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
|
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
|
||||||
return z.to(device)
|
return z.to(device)
|
||||||
|
|
||||||
def update_for_step(self, step, experiments_path='.'):
|
|
||||||
if self.act_norm_always_init and step > self.force_act_norm_init_until:
|
|
||||||
set_act_norm_always_init = True
|
|
||||||
set_value = False
|
|
||||||
self.act_norm_always_init = False
|
|
||||||
elif not self.act_norm_always_init and step < self.force_act_norm_init_until:
|
|
||||||
set_act_norm_always_init = True
|
|
||||||
set_value = True
|
|
||||||
self.act_norm_always_init = True
|
|
||||||
else:
|
|
||||||
set_act_norm_always_init = False
|
|
||||||
if set_act_norm_always_init:
|
|
||||||
for m in self.modules():
|
|
||||||
from models.archs.srflow_orig.FlowActNorms import _ActNorm
|
|
||||||
if isinstance(m, _ActNorm):
|
|
||||||
m.force_initialization = set_value
|
|
||||||
|
|
||||||
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
||||||
lr_enc=None,
|
lr_enc=None,
|
||||||
add_gt_noise=True, step=None, y_label=None):
|
add_gt_noise=True, step=None, y_label=None):
|
||||||
|
@ -170,7 +151,7 @@ class SRFlowNet(nn.Module):
|
||||||
|
|
||||||
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
|
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
|
||||||
logdet = torch.zeros_like(lr[:, 0, 0, 0])
|
logdet = torch.zeros_like(lr[:, 0, 0, 0])
|
||||||
pixels = thops.pixels(lr) * self.flow_scale ** 2
|
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
|
||||||
|
|
||||||
if add_gt_noise:
|
if add_gt_noise:
|
||||||
logdet = logdet - float(-np.log(self.quant) * pixels)
|
logdet = logdet - float(-np.log(self.quant) * pixels)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user