diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py
new file mode 100644
index 00000000..e38c4cfc
--- /dev/null
+++ b/codes/models/archs/ChainedEmbeddingGen.py
@@ -0,0 +1,64 @@
+import torch
+from torch import nn
+
+from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
+    FinalUpsampleBlock2x
+from models.archs.spinenet_arch import SpineNet
+from utils.util import checkpoint
+
+
+class BasicEmbeddingPyramid(nn.Module):
+    def __init__(self, use_norms=True):
+        super(BasicEmbeddingPyramid, self).__init__()
+        self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False)
+        self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
+                                       ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms),
+                                       ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
+                                       ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)])
+        self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu),
+                                        ExpansionBlock2(128, 64, block=ConvGnLelu)])
+        self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False)
+        self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms)
+        self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False)
+        self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms)
+
+        self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False,
+                                                      weight_init_factor=.1),
+                                           ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False,
+                                                      weight_init_factor=.1),
+                                           ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
+                                                      weight_init_factor=.1),
+                                           ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
+                                                      weight_init_factor=.1))
+
+    def forward(self, x, *embeddings):
+        p = self.initial_process(x)
+        identities = []
+        for i in range(2):
+            identities.append(p)
+            p = self.reducers[i*2](p)
+            p = self.reducers[i*2+1](p)
+            if i == 0:
+                p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0]))
+            elif i == 1:
+                p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1]))
+        for i in range(2):
+            p = self.expanders[i](p, identities[-(i+1)])
+        x = self.final_process(torch.cat([x, p], dim=1))
+        return x
+
+
+class ChainedEmbeddingGen(nn.Module):
+    def __init__(self):
+        super(ChainedEmbeddingGen, self).__init__()
+        self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
+        self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
+        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(5)])
+        self.upsample = FinalUpsampleBlock2x(64)
+
+    def forward(self, x):
+        emb = checkpoint(self.spine, x)
+        fea = self.initial_conv(x)
+        for block in self.blocks:
+            fea = fea + checkpoint(block, fea, *emb)
+        return checkpoint(self.upsample, fea),
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 418d26bb..394e2326 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -484,3 +484,16 @@ class UpconvBlock(nn.Module):
     def forward(self, x):
         x = F.interpolate(x, scale_factor=2, mode="nearest")
         return self.process(x)
+
+
+# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
+class FinalUpsampleBlock2x(nn.Module):
+    def __init__(self, nf, block=ConvGnLelu):
+        super(FinalUpsampleBlock2x, self).__init__()
+        self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
+                                   UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
+                                   block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
+                                   block(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False))
+
+    def forward(self, x):
+        return self.chain(x)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 212a40de..c02ae705 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -17,6 +17,7 @@
 from collections import OrderedDict
 import torchvision
 import functools
+from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen
 
 logger = logging.getLogger('base')
 
@@ -121,6 +122,8 @@ def define_G(opt, net_key='network_G', scale=None):
     elif which_model == 'artist':
         netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'],
                                           upscale=opt_net['scale'], init_temperature=opt_net['temperature'])
+    elif which_model == 'chained_gen':
+        netG = ChainedEmbeddingGen()
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])
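A note on the overall dataflow: the generator keeps a 64-channel trunk at LR resolution, and each of the five BasicEmbeddingPyramid blocks adds a residual correction conditioned on the two SpineNet feature maps; only FinalUpsampleBlock2x changes resolution, so the network is a fixed 2x upscaler. A minimal smoke-test sketch, assuming the repo's codes/ directory is on PYTHONPATH, that utils.util.checkpoint degrades to a plain call outside training, and that a 64x64 patch is large enough for SpineNet-49 (the minimum input size is an assumption):

import torch
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen

gen = ChainedEmbeddingGen().eval()
with torch.no_grad():
    lr = torch.randn(1, 3, 64, 64)        # one 3-channel LR patch
    out, = gen(lr)                        # the trailing comma in forward() makes the result a 1-tuple
assert out.shape == (1, 3, 128, 128)      # FinalUpsampleBlock2x doubles H and W

The 1-tuple return matches the convention elsewhere in this codebase of generators returning a tuple of outputs, so callers should unpack rather than use the result directly.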
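On wiring: the networks.py hunk registers the model under the 'chained_gen' key and forwards no options, since ChainedEmbeddingGen() takes no constructor arguments. A sketch of selecting it through define_G; the 'which_model_G' key name and the top-level 'scale' key are assumptions about the surrounding MMSR-style option schema, not shown in this diff:

from models.networks import define_G

opt = {
    'scale': 2,                                    # assumed: read by define_G when scale is None
    'network_G': {'which_model_G': 'chained_gen'}  # assumed key name for the model selector
}
netG = define_G(opt)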
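Both forward passes funnel the heavy modules through utils.util.checkpoint, whose implementation is not part of this diff. A plausible minimal reading (an assumption, not the repo's actual code) is that it defers to torch.utils.checkpoint, so activations of the SpineNet trunk and each pyramid block are recomputed during backward instead of stored:

import torch.utils.checkpoint

def checkpoint(fn, *args):
    # Trade memory for compute: recompute fn's activations in the backward pass.
    return torch.utils.checkpoint.checkpoint(fn, *args)

Because each block is applied as fea + checkpoint(block, fea, *emb), checkpointing segments the chain per block, capping stored activations at roughly one pyramid block plus the trunk at any point in the backward pass.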