Add ChainedEmbeddingGen
This commit is contained in:
parent
c4543ce124
commit
617d97e19d
64
codes/models/archs/ChainedEmbeddingGen.py
Normal file
64
codes/models/archs/ChainedEmbeddingGen.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
|
||||
FinalUpsampleBlock2x
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
class BasicEmbeddingPyramid(nn.Module):
|
||||
def __init__(self, use_norms=True):
|
||||
super(BasicEmbeddingPyramid, self).__init__()
|
||||
self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False)
|
||||
self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
|
||||
ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms),
|
||||
ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
|
||||
ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)])
|
||||
self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu),
|
||||
ExpansionBlock2(128, 64, block=ConvGnLelu)])
|
||||
self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False)
|
||||
self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms)
|
||||
self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False)
|
||||
self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms)
|
||||
|
||||
self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False,
|
||||
weight_init_factor=.1),
|
||||
ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False,
|
||||
weight_init_factor=.1),
|
||||
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
|
||||
weight_init_factor=.1),
|
||||
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
|
||||
weight_init_factor=.1))
|
||||
|
||||
def forward(self, x, *embeddings):
|
||||
p = self.initial_process(x)
|
||||
identities = []
|
||||
for i in range(2):
|
||||
identities.append(p)
|
||||
p = self.reducers[i*2](p)
|
||||
p = self.reducers[i*2+1](p)
|
||||
if i == 0:
|
||||
p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0]))
|
||||
elif i == 1:
|
||||
p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1]))
|
||||
for i in range(2):
|
||||
p = self.expanders[i](p, identities[-(i+1)])
|
||||
x = self.final_process(torch.cat([x, p], dim=1))
|
||||
return x
|
||||
|
||||
|
||||
class ChainedEmbeddingGen(nn.Module):
|
||||
def __init__(self):
|
||||
super(ChainedEmbeddingGen, self).__init__()
|
||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(5)])
|
||||
self.upsample = FinalUpsampleBlock2x(64)
|
||||
|
||||
def forward(self, x):
|
||||
emb = checkpoint(self.spine, x)
|
||||
fea = self.initial_conv(x)
|
||||
for block in self.blocks:
|
||||
fea = fea + checkpoint(block, fea, *emb)
|
||||
return checkpoint(self.upsample, fea),
|
|
@ -484,3 +484,16 @@ class UpconvBlock(nn.Module):
|
|||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
return self.process(x)
|
||||
|
||||
|
||||
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
|
||||
class FinalUpsampleBlock2x(nn.Module):
|
||||
def __init__(self, nf, block=ConvGnLelu):
|
||||
super(FinalUpsampleBlock2x, self).__init__()
|
||||
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
|
||||
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
|
||||
block(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False))
|
||||
|
||||
def forward(self, x):
|
||||
return self.chain(x)
|
||||
|
|
|
@ -17,6 +17,7 @@ from collections import OrderedDict
|
|||
import torchvision
|
||||
import functools
|
||||
|
||||
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -121,6 +122,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
elif which_model == 'artist':
|
||||
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
|
||||
init_temperature=opt_net['temperature'])
|
||||
elif which_model == 'chained_gen':
|
||||
netG = ChainedEmbeddingGen()
|
||||
elif which_model == "flownet2":
|
||||
from models.flownet2.models import FlowNet2
|
||||
ld = torch.load(opt_net['load_path'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user