Add ChainedEmbeddingGen
This commit is contained in:
parent
c4543ce124
commit
617d97e19d
64
codes/models/archs/ChainedEmbeddingGen.py
Normal file
64
codes/models/archs/ChainedEmbeddingGen.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
|
||||||
|
FinalUpsampleBlock2x
|
||||||
|
from models.archs.spinenet_arch import SpineNet
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
class BasicEmbeddingPyramid(nn.Module):
|
||||||
|
def __init__(self, use_norms=True):
|
||||||
|
super(BasicEmbeddingPyramid, self).__init__()
|
||||||
|
self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False)
|
||||||
|
self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
|
||||||
|
ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms),
|
||||||
|
ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
|
||||||
|
ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)])
|
||||||
|
self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu),
|
||||||
|
ExpansionBlock2(128, 64, block=ConvGnLelu)])
|
||||||
|
self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False)
|
||||||
|
self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms)
|
||||||
|
self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False)
|
||||||
|
self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms)
|
||||||
|
|
||||||
|
self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False,
|
||||||
|
weight_init_factor=.1),
|
||||||
|
ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False,
|
||||||
|
weight_init_factor=.1),
|
||||||
|
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
|
||||||
|
weight_init_factor=.1),
|
||||||
|
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
|
||||||
|
weight_init_factor=.1))
|
||||||
|
|
||||||
|
def forward(self, x, *embeddings):
|
||||||
|
p = self.initial_process(x)
|
||||||
|
identities = []
|
||||||
|
for i in range(2):
|
||||||
|
identities.append(p)
|
||||||
|
p = self.reducers[i*2](p)
|
||||||
|
p = self.reducers[i*2+1](p)
|
||||||
|
if i == 0:
|
||||||
|
p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0]))
|
||||||
|
elif i == 1:
|
||||||
|
p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1]))
|
||||||
|
for i in range(2):
|
||||||
|
p = self.expanders[i](p, identities[-(i+1)])
|
||||||
|
x = self.final_process(torch.cat([x, p], dim=1))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ChainedEmbeddingGen(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(ChainedEmbeddingGen, self).__init__()
|
||||||
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
|
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(5)])
|
||||||
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
emb = checkpoint(self.spine, x)
|
||||||
|
fea = self.initial_conv(x)
|
||||||
|
for block in self.blocks:
|
||||||
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
|
return checkpoint(self.upsample, fea),
|
|
@ -484,3 +484,16 @@ class UpconvBlock(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||||
return self.process(x)
|
return self.process(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
|
||||||
|
class FinalUpsampleBlock2x(nn.Module):
|
||||||
|
def __init__(self, nf, block=ConvGnLelu):
|
||||||
|
super(FinalUpsampleBlock2x, self).__init__()
|
||||||
|
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
|
||||||
|
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
|
||||||
|
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
|
||||||
|
block(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.chain(x)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from collections import OrderedDict
|
||||||
import torchvision
|
import torchvision
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
@ -121,6 +122,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
elif which_model == 'artist':
|
elif which_model == 'artist':
|
||||||
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
|
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
|
||||||
init_temperature=opt_net['temperature'])
|
init_temperature=opt_net['temperature'])
|
||||||
|
elif which_model == 'chained_gen':
|
||||||
|
netG = ChainedEmbeddingGen()
|
||||||
elif which_model == "flownet2":
|
elif which_model == "flownet2":
|
||||||
from models.flownet2.models import FlowNet2
|
from models.flownet2.models import FlowNet2
|
||||||
ld = torch.load(opt_net['load_path'])
|
ld = torch.load(opt_net['load_path'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user