Support 4x zoom on ChainedEmbeddingGen

This commit is contained in:
James Betker 2020-10-23 09:25:58 -06:00
parent 8636492db0
commit 646d6a621a
4 changed files with 23 additions and 11 deletions

View File

@ -197,14 +197,16 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
class MultifacetedChainedEmbeddingGen(nn.Module):
def __init__(self, depth=10):
def __init__(self, depth=10, scale=2):
super(MultifacetedChainedEmbeddingGen, self).__init__()
assert scale == 2 or scale == 4
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.prog_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
@ -212,9 +214,10 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)])
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
self.structure_upsample = FinalUpsampleBlock2x(64)
self.structure_upsample = FinalUpsampleBlock2x(64, scale=scale)
self.grad_extract = ImageGradientNoPadding()
self.upsample = FinalUpsampleBlock2x(64)
self.upsample = FinalUpsampleBlock2x(64, scale=scale)
self.teco_ref_std = 0
self.prog_ref_std = 0
self.block_residual_means = [0 for _ in range(depth)]

View File

@ -488,12 +488,20 @@ class UpconvBlock(nn.Module):
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
class FinalUpsampleBlock2x(nn.Module):
def __init__(self, nf, block=ConvGnLelu, out_nc=3):
def __init__(self, nf, block=ConvGnLelu, out_nc=3, scale=2):
super(FinalUpsampleBlock2x, self).__init__()
if scale == 2:
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
else:
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
def forward(self, x):
return self.chain(x)

View File

@ -140,7 +140,8 @@ def define_G(opt, net_key='network_G', scale=None):
bypass_bias = opt_net['bypass_bias'] if 'bypass_bias' in opt_net.keys() else 0
netG = StructuredChainedEmbeddingGenWithBypass(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, bypass_bias=bypass_bias)
elif which_model == 'multifaceted_chained':
netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'])
scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2
netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
elif which_model == "flownet2":
from models.flownet2.models import FlowNet2
ld = torch.load(opt_net['load_path'])

View File

@ -278,7 +278,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)