Add ChainedEmbeddingGen

This commit is contained in:
James Betker 2020-10-15 23:18:08 -06:00
parent c4543ce124
commit 617d97e19d
3 changed files with 80 additions and 0 deletions

View File

@ -0,0 +1,64 @@
import torch
from torch import nn
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
FinalUpsampleBlock2x
from models.archs.spinenet_arch import SpineNet
from utils.util import checkpoint
class BasicEmbeddingPyramid(nn.Module):
def __init__(self, use_norms=True):
super(BasicEmbeddingPyramid, self).__init__()
self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False)
self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms),
ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)])
self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu),
ExpansionBlock2(128, 64, block=ConvGnLelu)])
self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False)
self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms)
self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False)
self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms)
self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1),
ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1),
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1),
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1))
def forward(self, x, *embeddings):
p = self.initial_process(x)
identities = []
for i in range(2):
identities.append(p)
p = self.reducers[i*2](p)
p = self.reducers[i*2+1](p)
if i == 0:
p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0]))
elif i == 1:
p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1]))
for i in range(2):
p = self.expanders[i](p, identities[-(i+1)])
x = self.final_process(torch.cat([x, p], dim=1))
return x
class ChainedEmbeddingGen(nn.Module):
def __init__(self):
super(ChainedEmbeddingGen, self).__init__()
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(5)])
self.upsample = FinalUpsampleBlock2x(64)
def forward(self, x):
emb = checkpoint(self.spine, x)
fea = self.initial_conv(x)
for block in self.blocks:
fea = fea + checkpoint(block, fea, *emb)
return checkpoint(self.upsample, fea),

View File

@ -484,3 +484,16 @@ class UpconvBlock(nn.Module):
def forward(self, x): def forward(self, x):
x = F.interpolate(x, scale_factor=2, mode="nearest") x = F.interpolate(x, scale_factor=2, mode="nearest")
return self.process(x) return self.process(x)
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
class FinalUpsampleBlock2x(nn.Module):
def __init__(self, nf, block=ConvGnLelu):
super(FinalUpsampleBlock2x, self).__init__()
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
block(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False))
def forward(self, x):
return self.chain(x)

View File

@ -17,6 +17,7 @@ from collections import OrderedDict
import torchvision import torchvision
import functools import functools
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -121,6 +122,8 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == 'artist': elif which_model == 'artist':
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'], netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
init_temperature=opt_net['temperature']) init_temperature=opt_net['temperature'])
elif which_model == 'chained_gen':
netG = ChainedEmbeddingGen()
elif which_model == "flownet2": elif which_model == "flownet2":
from models.flownet2.models import FlowNet2 from models.flownet2.models import FlowNet2
ld = torch.load(opt_net['load_path']) ld = torch.load(opt_net['load_path'])