From 519ba6f10c769fa810b93cdcb1c12f3e08850e3b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 21 Nov 2020 14:46:15 -0700 Subject: [PATCH] Support 2x RRDB with 4x srflow --- .../archs/srflow_orig/FlowUpsamplerNet.py | 2 +- .../models/archs/srflow_orig/RRDBNet_arch.py | 51 ++++++++++++++----- .../archs/srflow_orig/SRFlowNet_arch.py | 11 ++-- codes/train2.py | 2 +- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py index 52fd918f..4a4c8ce5 100644 --- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py +++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py @@ -48,7 +48,7 @@ class FlowUpsamplerNet(nn.Module): 4: 'fea_up0' } - elif opt['scale'] == 4: + elif opt['scale'] == 4 or opt['scale'] == 2: self.levelToName = { 0: 'fea_up4', 1: 'fea_up2', diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py index a6d75117..607f4e0e 100644 --- a/codes/models/archs/srflow_orig/RRDBNet_arch.py +++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py @@ -186,21 +186,44 @@ class RRDBNet(nn.Module): out = self.conv_last(self.lrelu(self.conv_hr(fea))) - results = {'last_lr_fea': last_lr_fea, - 'fea_up1': last_lr_fea, - 'fea_up2': fea_up2, - 'fea_up4': fea_up4, - 'fea_up8': fea_up8, - 'fea_up16': fea_up16, - 'fea_up32': fea_up32, - 'out': out} + if self.scale >= 4: + results = {'last_lr_fea': last_lr_fea, + 'fea_up1': last_lr_fea, + 'fea_up2': fea_up2, + 'fea_up4': fea_up4, + 'fea_up8': fea_up8, + 'fea_up16': fea_up16, + 'fea_up32': fea_up32, + 'out': out} - fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False - if fea_up0_en: - results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) - fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False - if fea_upn1_en: - results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) + fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False + if fea_up0_en: + results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) + fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False + if fea_upn1_en: + results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) + elif self.scale == 2: + # "Pretend" this is is 4x by shuffling around the inputs a bit. + half = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) + quarter = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) + eighth = F.interpolate(last_lr_fea, scale_factor=1/8, mode='bilinear', align_corners=False, recompute_scale_factor=True) + results = {'last_lr_fea': half, + 'fea_up1': half, + 'fea_up2': last_lr_fea, + 'fea_up4': fea_up2, + 'fea_up8': fea_up4, + 'fea_up16': fea_up8, + 'fea_up32': fea_up16, + 'out': out} + + fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False + if fea_up0_en: + results['fea_up0'] = quarter + fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False + if fea_upn1_en: + results['fea_up-1'] = eighth + else: + raise NotImplementedError if get_steps: for k, v in block_results.items(): diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py index f972a8f4..66a52f26 100644 --- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py +++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py @@ -27,6 +27,7 @@ class SRFlowNet(nn.Module): hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels']) hidden_channels = hidden_channels or 64 self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False) + self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) #