Support 4x zoom on ChainedEmbeddingGen

This commit is contained in:
parent 8636492db0
commit 646d6a621a
@ -197,14 +197,16 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
|||
|
||||
|
||||
class MultifacetedChainedEmbeddingGen(nn.Module):
|
||||
def __init__(self, depth=10):
|
||||
def __init__(self, depth=10, scale=2):
|
||||
super(MultifacetedChainedEmbeddingGen, self).__init__()
|
||||
assert scale == 2 or scale == 4
|
||||
|
||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||
|
||||
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
|
||||
self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
||||
|
||||
self.prog_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
|
||||
self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
|
||||
self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
||||
|
||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||
|
@ -212,9 +214,10 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
|
|||
self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)])
|
||||
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
||||
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
|
||||
self.structure_upsample = FinalUpsampleBlock2x(64)
|
||||
self.structure_upsample = FinalUpsampleBlock2x(64, scale=scale)
|
||||
self.grad_extract = ImageGradientNoPadding()
|
||||
self.upsample = FinalUpsampleBlock2x(64)
|
||||
self.upsample = FinalUpsampleBlock2x(64, scale=scale)
|
||||
|
||||
self.teco_ref_std = 0
|
||||
self.prog_ref_std = 0
|
||||
self.block_residual_means = [0 for _ in range(depth)]
|
||||
|
|
|
# Scales an image up 2x or 4x and performs intermediary processing. Designed to
# be the final block in an SR network.
class FinalUpsampleBlock2x(nn.Module):
    """Final upsampling head for a super-resolution network.

    Args:
        nf: number of input feature channels.
        block: conv-block constructor used for the intermediary convolutions
            (defaults to ConvGnLelu).
        out_nc: number of output channels (3 for an RGB image).
        scale: total spatial upsampling factor; must be 2 or 4.

    Raises:
        ValueError: if ``scale`` is neither 2 nor 4. (Previously any value
            other than 2 was silently treated as 4.)
    """
    def __init__(self, nf, block=ConvGnLelu, out_nc=3, scale=2):
        super(FinalUpsampleBlock2x, self).__init__()
        if scale not in (2, 4):
            raise ValueError("scale must be 2 or 4, got %r" % (scale,))
        if scale == 2:
            self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
                                       UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
                                       block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
                                       block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
        else:
            # 4x: two successive 2x upconv stages. The channel reduction to
            # nf // 2 happens only in the second stage so the tail matches the
            # 2x head's output shape.
            self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
                                       UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
                                       block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
                                       UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
                                       block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
                                       block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))

    def forward(self, x):
        # The whole head is a plain feed-forward chain; no skip connections.
        return self.chain(x)
|
|
@ -140,7 +140,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
bypass_bias = opt_net['bypass_bias'] if 'bypass_bias' in opt_net.keys() else 0
|
||||
netG = StructuredChainedEmbeddingGenWithBypass(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, bypass_bias=bypass_bias)
|
||||
elif which_model == 'multifaceted_chained':
|
||||
netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'])
|
||||
scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2
|
||||
netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
|
||||
elif which_model == "flownet2":
|
||||
from models.flownet2.models import FlowNet2
|
||||
ld = torch.load(opt_net['load_path'])
|
||||
|
|
|
@ -278,7 +278,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user