From 71fa532356f77243dffc159856b33a921734abe5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 27 Nov 2020 21:37:00 -0700 Subject: [PATCH] Adjustments to how flow networks set size and scale --- .../archs/srflow_orig/FlowUpsamplerNet.py | 27 ++++++++++--------- .../models/archs/srflow_orig/RRDBNet_arch.py | 10 ++++--- .../archs/srflow_orig/SRFlowNet_arch.py | 3 ++- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py index c208fbc4..0b595128 100644 --- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py @@ -23,6 +23,7 @@ class FlowUpsamplerNet(nn.Module): self.output_shapes = [] self.L = opt_get(opt, ['networks', 'generator','flow', 'L']) self.K = opt_get(opt, ['networks', 'generator','flow', 'K']) + self.patch_sz = opt_get(opt, ['networks', 'generator', 'flow', 'patch_size'], 160) if isinstance(self.K, int): self.K = [K for K in [K, ] * (self.L + 1)] @@ -30,7 +31,7 @@ class FlowUpsamplerNet(nn.Module): H, W, self.C = image_shape self.check_image_shape() - if opt['scale'] == 16: + if opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 16: self.levelToName = { 0: 'fea_up16', 1: 'fea_up8', @@ -39,7 +40,7 @@ class FlowUpsamplerNet(nn.Module): 4: 'fea_up1', } - if opt['scale'] == 8: + if opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 8: self.levelToName = { 0: 'fea_up8', 1: 'fea_up4', @@ -48,7 +49,7 @@ class FlowUpsamplerNet(nn.Module): 4: 'fea_up0' } - elif opt['scale'] == 4 or opt['scale'] == 2: + elif opt_get(self.opt, ['networks', 'generator', 'flow_scale']) == 4: self.levelToName = { 0: 'fea_up4', 1: 'fea_up2', @@ -95,8 +96,8 @@ class FlowUpsamplerNet(nn.Module): self.H = H self.W = W - self.scaleH = 160 / H - self.scaleW = 160 / W + self.scaleH = self.patch_sz / H + self.scaleW = self.patch_sz / W def get_n_rrdb_channels(self, opt, opt_get): blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) @@ -110,7 +111,7 @@ class FlowUpsamplerNet(nn.Module): condAff['in_channels_rrdb'] = n_conditinal_channels for k in range(K): - position_name = get_position_name(H, self.opt['scale']) + position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale'])) if normOpt: normOpt['position'] = position_name self.layers.append( @@ -136,7 +137,7 @@ class FlowUpsamplerNet(nn.Module): if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction: logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0 consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5 - position_name = get_position_name(H, self.opt['scale']) + position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale'])) position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels']) cond_channels = 0 if cond_channels is None else cond_channels @@ -210,7 +211,7 @@ class FlowUpsamplerNet(nn.Module): for layer, shape in zip(self.layers, self.output_shapes): size = shape[2] - level = int(np.log(160 / size) / np.log(2)) + level = int(np.log(self.patch_sz / size) / np.log(2)) if level > 0 and level not in level_conditionals.keys(): level_conditionals[level] = rrdbResults[self.levelToName[level]] @@ -258,7 +259,7 @@ class FlowUpsamplerNet(nn.Module): for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): size = shape[2] - level = int(np.log(160 / size) / np.log(2)) + level = int(np.log(self.patch_sz / size) / np.log(2)) # size = fl_fea.shape[2] # level = int(np.log(160 / size) / np.log(2)) @@ -284,7 +285,7 @@ class FlowUpsamplerNet(nn.Module): return fl_fea, logdet -def get_position_name(H, scale): - downscale_factor = 160 // H - position_name = 'fea_up{}'.format(scale / downscale_factor) - return position_name + def get_position_name(self, H, scale): + downscale_factor = self.patch_sz // H + position_name = 'fea_up{}'.format(scale / downscale_factor) + return position_name diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py index 3828d7b8..7028e163 100644 --- a/codes/models/archs/srflow_orig/RRDBNet_arch.py +++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py @@ -135,7 +135,8 @@ class RRDBNet(nn.Module): self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) #### upsampling self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.scale >= 2: + self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 8: self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 16: @@ -167,13 +168,14 @@ class RRDBNet(nn.Module): fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up2) - fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')) - fea = self.lrelu(fea_up4) - + fea_up4 = None fea_up8 = None fea_up16 = None fea_up32 = None + if self.scale >= 4: + fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea = self.lrelu(fea_up4) if self.scale >= 8: fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up8) diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index 8973c69c..933fe57e 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -29,8 +29,9 @@ class SRFlowNet(nn.Module): self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False) self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) #