# Patch summary (text before the first "diff --git" line is ignored by git apply / patch).
#
# codes/models/archs/ChainedEmbeddingGen.py
#   - ChainedEmbeddingGen.__init__ gains `depth=10`; the number of BasicEmbeddingPyramid
#     blocks changes from a fixed 5 to `depth`.
#   - Adds ChainedEmbeddingGenWithStructure and ChainedEmbeddingGenWithStructureR2. Both
#     keep a second `grad` feature that is updated during the first three chain blocks and
#     return a 3-tuple: (out, grad_extract(structure_upsample(grad)), grad_extract(out)).
#     R2 additionally folds `grad` back into `fea` through `structure_rejoin` at i == 3.
# codes/models/networks.py
#   - define_G passes `depth=opt_net['depth']` for 'chained_gen' and registers
#     'chained_gen_structured' and 'chained_gen_structuredr2'.
# codes/train.py, codes/train2.py
#   - Only the default value of the `-opt` argument changes.
#
# NOTE(review): 'chained_gen' now reads opt_net['depth'] unconditionally, and the class
#   default moves from 5 blocks to 10. Option files written before this change have no
#   reason to carry a `depth` key -- confirm how opt_net behaves on a missing key and
#   whether older checkpoints with 5 blocks still need to load.
# NOTE(review): in R2, `structure_rejoin` runs only when the loop reaches i == 3, i.e. when
#   depth >= 4, and it is called directly rather than through `checkpoint` like the other
#   sub-modules. With depth < 3 some `structure_joins` / `structure_blocks` entries are
#   never used. Confirm these are intended.
# NOTE(layout): this patch was recovered from a copy whose line breaks had been collapsed.
#   Leading indentation, blank context lines and the wrap position of the ArtistGen(...)
#   call in networks.py are reconstructed, not preserved -- verify against the target tree
#   before applying.
diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py
index e38c4cfc..00ea4d13 100644
--- a/codes/models/archs/ChainedEmbeddingGen.py
+++ b/codes/models/archs/ChainedEmbeddingGen.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 
+from models.archs.SPSR_arch import ImageGradientNoPadding
 from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
     FinalUpsampleBlock2x
 from models.archs.spinenet_arch import SpineNet
@@ -49,11 +50,11 @@ class BasicEmbeddingPyramid(nn.Module):
 
 
 class ChainedEmbeddingGen(nn.Module):
-    def __init__(self):
+    def __init__(self, depth=10):
         super(ChainedEmbeddingGen, self).__init__()
         self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
         self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
-        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(5)])
+        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
         self.upsample = FinalUpsampleBlock2x(64)
 
     def forward(self, x):
@@ -62,3 +63,56 @@ class ChainedEmbeddingGen(nn.Module):
         for block in self.blocks:
             fea = fea + checkpoint(block, fea, *emb)
         return checkpoint(self.upsample, fea),
+
+
+class ChainedEmbeddingGenWithStructure(nn.Module):
+    def __init__(self, depth=10):
+        super(ChainedEmbeddingGenWithStructure, self).__init__()
+        self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
+        self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
+        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
+        self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
+        self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
+        self.structure_upsample = FinalUpsampleBlock2x(64)
+        self.grad_extract = ImageGradientNoPadding()
+        self.upsample = FinalUpsampleBlock2x(64)
+
+    def forward(self, x):
+        emb = checkpoint(self.spine, x)
+        fea = self.initial_conv(x)
+        grad = fea
+        for i, block in enumerate(self.blocks):
+            fea = fea + checkpoint(block, fea, *emb)
+            if i < 3:
+                structure_br = checkpoint(self.structure_joins[i], grad, fea)
+                grad = grad + checkpoint(self.structure_blocks[i], structure_br)
+        out = checkpoint(self.upsample, fea)
+        return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
+
+
+class ChainedEmbeddingGenWithStructureR2(nn.Module):
+    def __init__(self, depth=10):
+        super(ChainedEmbeddingGenWithStructureR2, self).__init__()
+        self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
+        self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
+        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
+        self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
+        self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
+        self.structure_upsample = FinalUpsampleBlock2x(64)
+        self.structure_rejoin = ConjoinBlock(64)
+        self.grad_extract = ImageGradientNoPadding()
+        self.upsample = FinalUpsampleBlock2x(64)
+
+    def forward(self, x):
+        emb = checkpoint(self.spine, x)
+        fea = self.initial_conv(x)
+        grad = fea
+        for i, block in enumerate(self.blocks):
+            fea = fea + checkpoint(block, fea, *emb)
+            if i < 3:
+                structure_br = checkpoint(self.structure_joins[i], grad, fea)
+                grad = grad + checkpoint(self.structure_blocks[i], structure_br)
+            if i == 3:
+                fea = fea + self.structure_rejoin(fea, grad)
+        out = checkpoint(self.upsample, fea)
+        return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index c02ae705..a6841346 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -17,7 +17,8 @@ from collections import OrderedDict
 
 import torchvision
 import functools
-from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen
+from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
+    ChainedEmbeddingGenWithStructureR2
 
 logger = logging.getLogger('base')
 
@@ -123,7 +124,11 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
                                           init_temperature=opt_net['temperature'])
     elif which_model == 'chained_gen':
-        netG = ChainedEmbeddingGen()
+        netG = ChainedEmbeddingGen(depth=opt_net['depth'])
+    elif which_model == 'chained_gen_structured':
+        netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'])
+    elif which_model == 'chained_gen_structuredr2':
+        netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth'])
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])
diff --git a/codes/train.py b/codes/train.py
index 510249aa..de9ce78d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgsimpler.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structuredr2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index 7bcdf092..396ab01d 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgdeep.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()