Adjustments to how flow networks set size and scale
This commit is contained in:
parent
6f958bb150
commit
71fa532356
|
@ -23,6 +23,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
self.output_shapes = []
|
||||
self.L = opt_get(opt, ['networks', 'generator','flow', 'L'])
|
||||
self.K = opt_get(opt, ['networks', 'generator','flow', 'K'])
|
||||
self.patch_sz = opt_get(opt, ['networks', 'generator', 'flow', 'patch_size'], 160)
|
||||
if isinstance(self.K, int):
|
||||
self.K = [K for K in [K, ] * (self.L + 1)]
|
||||
|
||||
|
@ -30,7 +31,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
H, W, self.C = image_shape
|
||||
self.check_image_shape()
|
||||
|
||||
if opt['scale'] == 16:
|
||||
if opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 16:
|
||||
self.levelToName = {
|
||||
0: 'fea_up16',
|
||||
1: 'fea_up8',
|
||||
|
@ -39,7 +40,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
4: 'fea_up1',
|
||||
}
|
||||
|
||||
if opt['scale'] == 8:
|
||||
if opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 8:
|
||||
self.levelToName = {
|
||||
0: 'fea_up8',
|
||||
1: 'fea_up4',
|
||||
|
@ -48,7 +49,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
4: 'fea_up0'
|
||||
}
|
||||
|
||||
elif opt['scale'] == 4 or opt['scale'] == 2:
|
||||
elif opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 4:
|
||||
self.levelToName = {
|
||||
0: 'fea_up4',
|
||||
1: 'fea_up2',
|
||||
|
@ -95,8 +96,8 @@ class FlowUpsamplerNet(nn.Module):
|
|||
|
||||
self.H = H
|
||||
self.W = W
|
||||
self.scaleH = 160 / H
|
||||
self.scaleW = 160 / W
|
||||
self.scaleH = self.patch_sz / H
|
||||
self.scaleW = self.patch_sz / W
|
||||
|
||||
def get_n_rrdb_channels(self, opt, opt_get):
|
||||
blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks'])
|
||||
|
@ -110,7 +111,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
condAff['in_channels_rrdb'] = n_conditinal_channels
|
||||
|
||||
for k in range(K):
|
||||
position_name = get_position_name(H, self.opt['scale'])
|
||||
position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale']))
|
||||
if normOpt: normOpt['position'] = position_name
|
||||
|
||||
self.layers.append(
|
||||
|
@ -136,7 +137,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction:
|
||||
logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0
|
||||
consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5
|
||||
position_name = get_position_name(H, self.opt['scale'])
|
||||
position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale']))
|
||||
position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None
|
||||
cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels'])
|
||||
cond_channels = 0 if cond_channels is None else cond_channels
|
||||
|
@ -210,7 +211,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
|
||||
for layer, shape in zip(self.layers, self.output_shapes):
|
||||
size = shape[2]
|
||||
level = int(np.log(160 / size) / np.log(2))
|
||||
level = int(np.log(self.patch_sz / size) / np.log(2))
|
||||
|
||||
if level > 0 and level not in level_conditionals.keys():
|
||||
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
||||
|
@ -258,7 +259,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
|
||||
for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
|
||||
size = shape[2]
|
||||
level = int(np.log(160 / size) / np.log(2))
|
||||
level = int(np.log(self.patch_sz / size) / np.log(2))
|
||||
# size = fl_fea.shape[2]
|
||||
# level = int(np.log(160 / size) / np.log(2))
|
||||
|
||||
|
@ -284,7 +285,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
return fl_fea, logdet
|
||||
|
||||
|
||||
def get_position_name(H, scale):
|
||||
downscale_factor = 160 // H
|
||||
def get_position_name(self, H, scale):
|
||||
downscale_factor = self.patch_sz // H
|
||||
position_name = 'fea_up{}'.format(scale / downscale_factor)
|
||||
return position_name
|
||||
|
|
|
@ -135,6 +135,7 @@ class RRDBNet(nn.Module):
|
|||
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.scale >= 2:
|
||||
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.scale >= 8:
|
||||
self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
|
@ -167,13 +168,14 @@ class RRDBNet(nn.Module):
|
|||
fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up2)
|
||||
|
||||
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up4)
|
||||
|
||||
fea_up4 = None
|
||||
fea_up8 = None
|
||||
fea_up16 = None
|
||||
fea_up32 = None
|
||||
|
||||
if self.scale >= 4:
|
||||
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up4)
|
||||
if self.scale >= 8:
|
||||
fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up8)
|
||||
|
|
|
@ -29,8 +29,9 @@ class SRFlowNet(nn.Module):
|
|||
self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False)
|
||||
self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) # <!-- hack to enable RRDB to do 2x scaling while retaining the flow architecture of 4x.
|
||||
|
||||
self.patch_sz = opt_get(self.opt, ['networks', 'generator', 'flow', 'patch_size'], 160)
|
||||
self.flowUpsamplerNet = \
|
||||
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
||||
FlowUpsamplerNet((self.patch_sz, self.patch_sz, 3), hidden_channels, K,
|
||||
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
|
||||
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
|
||||
self.act_norm_always_init = False
|
||||
|
|
Loading…
Reference in New Issue
Block a user