From 24bdcc11819d237bdb1580c96e4874611ba29a07 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 18 Aug 2020 09:10:25 -0600 Subject: [PATCH] Let SwitchedSpsr transform count be specified --- codes/models/archs/SPSR_arch.py | 6 +++--- codes/models/networks.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 86d0f07b..965d810e 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -228,11 +228,11 @@ class SPSRNet(nn.Module): x_out = self._branch_pretrain_HR_conv1(x_out) ######### - return x_out_branch, x_out, x_grad + return x_out_branch, x_out, x_gradn class SwitchedSpsr(nn.Module): - def __init__(self, in_nc, out_nc, nf, upscale=4): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4): super(SwitchedSpsr, self).__init__() n_upscale = int(math.log(upscale, 2)) @@ -241,7 +241,7 @@ class SwitchedSpsr(nn.Module): switch_filters = nf switch_reductions = 3 switch_processing_layers = 2 - self.transformation_counts = 8 + self.transformation_counts = xforms multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, self.transformation_counts, use_exp2=True) pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) diff --git a/codes/models/networks.py b/codes/models/networks.py index 85397f33..4d8a0d25 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -112,7 +112,8 @@ def define_G(opt, net_key='network_G'): netG = spsr.SPSRNetSimplifiedNoSkip(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == "spsr_switched": - netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], upscale=opt_net['scale']) + xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 + netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale']) # image corruption elif which_model == 'HighToLowResNet':