Adjustments to how flow networks set size and scale

This commit is contained in:
James Betker 2020-11-27 21:37:00 -07:00
parent 6f958bb150
commit 71fa532356
3 changed files with 22 additions and 18 deletions

View File

@ -23,6 +23,7 @@ class FlowUpsamplerNet(nn.Module):
self.output_shapes = [] self.output_shapes = []
self.L = opt_get(opt, ['networks', 'generator','flow', 'L']) self.L = opt_get(opt, ['networks', 'generator','flow', 'L'])
self.K = opt_get(opt, ['networks', 'generator','flow', 'K']) self.K = opt_get(opt, ['networks', 'generator','flow', 'K'])
self.patch_sz = opt_get(opt, ['networks', 'generator', 'flow', 'patch_size'], 160)
if isinstance(self.K, int): if isinstance(self.K, int):
self.K = [K for K in [K, ] * (self.L + 1)] self.K = [K for K in [K, ] * (self.L + 1)]
@ -30,7 +31,7 @@ class FlowUpsamplerNet(nn.Module):
H, W, self.C = image_shape H, W, self.C = image_shape
self.check_image_shape() self.check_image_shape()
if opt['scale'] == 16: if opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 16:
self.levelToName = { self.levelToName = {
0: 'fea_up16', 0: 'fea_up16',
1: 'fea_up8', 1: 'fea_up8',
@ -39,7 +40,7 @@ class FlowUpsamplerNet(nn.Module):
4: 'fea_up1', 4: 'fea_up1',
} }
if opt['scale'] == 8: if opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 8:
self.levelToName = { self.levelToName = {
0: 'fea_up8', 0: 'fea_up8',
1: 'fea_up4', 1: 'fea_up4',
@ -48,7 +49,7 @@ class FlowUpsamplerNet(nn.Module):
4: 'fea_up0' 4: 'fea_up0'
} }
elif opt['scale'] == 4 or opt['scale'] == 2: elif opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 4:
self.levelToName = { self.levelToName = {
0: 'fea_up4', 0: 'fea_up4',
1: 'fea_up2', 1: 'fea_up2',
@ -95,8 +96,8 @@ class FlowUpsamplerNet(nn.Module):
self.H = H self.H = H
self.W = W self.W = W
self.scaleH = 160 / H self.scaleH = self.patch_sz / H
self.scaleW = 160 / W self.scaleW = self.patch_sz / W
def get_n_rrdb_channels(self, opt, opt_get): def get_n_rrdb_channels(self, opt, opt_get):
blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks'])
@ -110,7 +111,7 @@ class FlowUpsamplerNet(nn.Module):
condAff['in_channels_rrdb'] = n_conditinal_channels condAff['in_channels_rrdb'] = n_conditinal_channels
for k in range(K): for k in range(K):
position_name = get_position_name(H, self.opt['scale']) position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale']))
if normOpt: normOpt['position'] = position_name if normOpt: normOpt['position'] = position_name
self.layers.append( self.layers.append(
@ -136,7 +137,7 @@ class FlowUpsamplerNet(nn.Module):
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction: if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction:
logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0 logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0
consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5 consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5
position_name = get_position_name(H, self.opt['scale']) position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale']))
position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None
cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels']) cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels'])
cond_channels = 0 if cond_channels is None else cond_channels cond_channels = 0 if cond_channels is None else cond_channels
@ -210,7 +211,7 @@ class FlowUpsamplerNet(nn.Module):
for layer, shape in zip(self.layers, self.output_shapes): for layer, shape in zip(self.layers, self.output_shapes):
size = shape[2] size = shape[2]
level = int(np.log(160 / size) / np.log(2)) level = int(np.log(self.patch_sz / size) / np.log(2))
if level > 0 and level not in level_conditionals.keys(): if level > 0 and level not in level_conditionals.keys():
level_conditionals[level] = rrdbResults[self.levelToName[level]] level_conditionals[level] = rrdbResults[self.levelToName[level]]
@ -258,7 +259,7 @@ class FlowUpsamplerNet(nn.Module):
for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
size = shape[2] size = shape[2]
level = int(np.log(160 / size) / np.log(2)) level = int(np.log(self.patch_sz / size) / np.log(2))
# size = fl_fea.shape[2] # size = fl_fea.shape[2]
# level = int(np.log(160 / size) / np.log(2)) # level = int(np.log(160 / size) / np.log(2))
@ -284,7 +285,7 @@ class FlowUpsamplerNet(nn.Module):
return fl_fea, logdet return fl_fea, logdet
def get_position_name(H, scale): def get_position_name(self, H, scale):
downscale_factor = 160 // H downscale_factor = self.patch_sz // H
position_name = 'fea_up{}'.format(scale / downscale_factor) position_name = 'fea_up{}'.format(scale / downscale_factor)
return position_name return position_name

View File

@ -135,7 +135,8 @@ class RRDBNet(nn.Module):
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling #### upsampling
self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 2:
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 8: if self.scale >= 8:
self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 16: if self.scale >= 16:
@ -167,13 +168,14 @@ class RRDBNet(nn.Module):
fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up2) fea = self.lrelu(fea_up2)
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')) fea_up4 = None
fea = self.lrelu(fea_up4)
fea_up8 = None fea_up8 = None
fea_up16 = None fea_up16 = None
fea_up32 = None fea_up32 = None
if self.scale >= 4:
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up4)
if self.scale >= 8: if self.scale >= 8:
fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest')) fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up8) fea = self.lrelu(fea_up8)

View File

@ -29,8 +29,9 @@ class SRFlowNet(nn.Module):
self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False) self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False)
self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) # <!-- hack to enable RRDB to do 2x scaling while retaining the flow architecture of 4x. self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) # <!-- hack to enable RRDB to do 2x scaling while retaining the flow architecture of 4x.
self.patch_sz = opt_get(self.opt, ['networks', 'generator', 'flow', 'patch_size'], 160)
self.flowUpsamplerNet = \ self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K, FlowUpsamplerNet((self.patch_sz, self.patch_sz, 3), hidden_channels, K,
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt) flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step']) self.force_act_norm_init_until = opt_get(self.opt, ['networks', 'generator', 'flow', 'act_norm_start_step'])
self.act_norm_always_init = False self.act_norm_always_init = False