From 646d6a621a0f0ff1a2cd5999d33414b82664a200 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 23 Oct 2020 09:25:58 -0600 Subject: [PATCH] Support 4x zoom on ChainedEmbeddingGen --- codes/models/archs/ChainedEmbeddingGen.py | 11 +++++++---- codes/models/archs/arch_util.py | 18 +++++++++++++----- codes/models/networks.py | 3 ++- codes/train2.py | 2 +- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index a425a322..90a9e426 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -197,14 +197,16 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module): class MultifacetedChainedEmbeddingGen(nn.Module): - def __init__(self, depth=10): + def __init__(self, depth=10, scale=2): super(MultifacetedChainedEmbeddingGen, self).__init__() + assert scale == 2 or scale == 4 + self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False) self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - self.prog_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False) + self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False) self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) @@ -212,9 +214,10 @@ class MultifacetedChainedEmbeddingGen(nn.Module): self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)]) self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)]) - self.structure_upsample = FinalUpsampleBlock2x(64) + self.structure_upsample = FinalUpsampleBlock2x(64, scale=scale) self.grad_extract = ImageGradientNoPadding() - self.upsample = FinalUpsampleBlock2x(64) + self.upsample = FinalUpsampleBlock2x(64, scale=scale) + self.teco_ref_std = 0 self.prog_ref_std = 0 self.block_residual_means = [0 for _ in range(depth)] diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index 5e5c44b4..04d49ee0 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -488,12 +488,20 @@ class UpconvBlock(nn.Module): # Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network. class FinalUpsampleBlock2x(nn.Module): - def __init__(self, nf, block=ConvGnLelu, out_nc=3): + def __init__(self, nf, block=ConvGnLelu, out_nc=3, scale=2): super(FinalUpsampleBlock2x, self).__init__() - self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True), - UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True), - block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True), - block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)) + if scale == 2: + self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True), + UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True), + block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True), + block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)) + else: + self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True), + UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True), + block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True), + UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True), + block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True), + block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)) def forward(self, x): return self.chain(x) diff --git a/codes/models/networks.py b/codes/models/networks.py index abeecefc..176a6d2b 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -140,7 +140,8 @@ def define_G(opt, net_key='network_G', scale=None): bypass_bias = opt_net['bypass_bias'] if 'bypass_bias' in opt_net.keys() else 0 netG = StructuredChainedEmbeddingGenWithBypass(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, bypass_bias=bypass_bias) elif which_model == 'multifaceted_chained': - netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth']) + scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2 + netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale) elif which_model == "flownet2": from models.flownet2.models import FlowNet2 ld = torch.load(opt_net['load_path']) diff --git a/codes/train2.py b/codes/train2.py index c7232118..81bd7c35 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -278,7 +278,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)