From 1f20d59c31a835cb516fb03336aca740731e424b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 14 Oct 2020 11:03:34 -0600
Subject: [PATCH] Revert big switch back

---
 .../archs/SwitchedResidualGenerator_arch.py | 50 ++++---------------
 codes/models/networks.py                    |  2 +-
 codes/train2.py                             |  2 +-
 3 files changed, 11 insertions(+), 43 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 79837f99..eb49aa62 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -606,7 +606,7 @@ class SwitchModelBase(nn.Module):
 from models.archs.spinenet_arch import make_res_layer, BasicBlock
 
 class BigMultiplexer(nn.Module):
-    def __init__(self, in_nc, nf, mode, multiplexer_channels):
+    def __init__(self, in_nc, nf, multiplexer_channels):
         super(BigMultiplexer, self).__init__()
 
         self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False)
@@ -615,54 +615,28 @@ class BigMultiplexer(nn.Module):
         self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2)
         self.tail_join = ReferenceJoinBlock(nf)
 
-        self.mode = mode
-        if mode == 0:
-            self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
-            self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=3, activation=True, norm=False, bias=False)
-            self.cbl0 = ConvGnSilu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False)
-            self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
-            self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
-        else:
-            self.key_process = ConvGnSilu(nf, nf, kernel_size=3, activation=True, norm=False, bias=True)
-            self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=True, bias=False)
-            self.cbl0 = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=True, bias=False)
-            self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, activation=True, norm=False, bias=False)
-            self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, activation=False, norm=False, bias=False)
+        self.reduce = nn.Sequential(ConvGnSilu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=False),
+                                    ConvGnSilu(nf // 2, multiplexer_channels, kernel_size=1, activation=False, norm=False, bias=False))
 
     def forward(self, x, transformations):
         s = self.spine(x)[0]
         tail = self.fea_tail(x)
         tail = self.tail_proc(tail)
-        if self.mode == 0:
-            q = F.interpolate(s, scale_factor=2, mode='bilinear')
-        else:
-            q = F.interpolate(s, scale_factor=2, mode='nearest')
+        q = F.interpolate(s, scale_factor=2, mode='nearest')
         q = self.spine_red_proc(q)
         q, _ = self.tail_join(q, tail)
-
-        b, t, f, h, w = transformations.shape
-        k = transformations.view(b * t, f, h, w)
-        k = self.key_process(k)
-
-        q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
-        v = self.query_key_combine(torch.cat([q, k], dim=1))
-        v = self.cbl0(v)
-        v = self.cbl1(v)
-        v = self.cbl2(v)
-
-        return v.view(b, t, h, w)
+        return self.reduce(q)
 
 
 class TheBigSwitch(SwitchModelBase):
-    def __init__(self, in_nc, nf, xforms=16, upscale=2, mode=0, init_temperature=10):
+    def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10):
         super(TheBigSwitch, self).__init__(init_temperature, 10000)
         self.nf = nf
         self.transformation_counts = xforms
-        self.mode = mode
 
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
-        multiplx_fn = functools.partial(BigMultiplexer, in_nc, nf, mode)
+        multiplx_fn = functools.partial(BigMultiplexer, in_nc, nf)
         transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), nf, kernel_size=3, depth=4,
                                          weight_init_factor=.1)
         self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
                                                  pre_transform_block=None, transform_block=transform_fn,
@@ -686,20 +660,14 @@ class TheBigSwitch(SwitchModelBase):
             sw.set_update_attention_norm(save_attentions)
 
         x1 = self.model_fea_conv(x)
-        if self.mode == 0:
-            x1, a1 = self.switch(x1, att_in=x, do_checkpointing=True)
-        else:
-            x1, a1, attlogits = self.switch(x1, att_in=x, do_checkpointing=True, output_att_logits=True)
+        x1, a1 = self.switch(x1, att_in=x, do_checkpointing=True)
         x_out = checkpoint(self.final_lr_conv, x1)
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv2, x_out)
 
         if save_attentions:
             self.attentions = [a1]
-        if self.mode == 0:
-            return x_out,
-        else:
-            return x_out, attlogits.permute(0,3,1,2)
+        return x_out,
 
 
 if __name__ == '__main__':
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 3e6ebeed..e831d690 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -111,7 +111,7 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = ssg.StackedSwitchGenerator2xTeco(nf=opt_net['nf'], xforms=opt_net['num_transforms'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
     elif which_model == 'big_switch':
         netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
-                                             init_temperature=opt_net['temperature'], mode=opt_net['mode'])
+                                             init_temperature=opt_net['temperature'])
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])
diff --git a/codes/train2.py b/codes/train2.py
index bccb5119..7bcdf092 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_bigswitch_att_invariance.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgdeep.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
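
Usage note (not part of the patch): after this revert, TheBigSwitch no longer accepts a mode argument and its forward pass always returns a one-element tuple. A minimal sketch of the post-revert call, mirroring the define_G hookup above; the concrete values (in_nc=3, nf=64, a 64x64 input) and calling forward() with only the input tensor are illustrative assumptions, while the constructor signature itself is confirmed by the diff:

# Minimal sketch of the reverted interface (assumptions noted inline).
import torch
from models.archs.SwitchedResidualGenerator_arch import TheBigSwitch

# Constructor no longer takes `mode`; these hyperparameter values are assumed.
netG = TheBigSwitch(3, nf=64, xforms=16, upscale=2, init_temperature=10)
lr = torch.randn(1, 3, 64, 64)  # illustrative low-res input batch
x_out, = netG(lr)               # forward now returns a one-element tuple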