Add a ChainedEmbeddingGen which can be simueltaneously used with multiple training paradigms

This commit is contained in:
James Betker 2020-10-21 22:21:51 -06:00
parent 931aa65dd0
commit 1ef559d7ca
3 changed files with 85 additions and 10 deletions

View File

@ -53,12 +53,12 @@ class BasicEmbeddingPyramid(nn.Module):
class ChainedEmbeddingGen(nn.Module):
def __init__(self, depth=10):
def __init__(self, depth=10, in_nc=3):
super(ChainedEmbeddingGen, self).__init__()
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
self.upsample = FinalUpsampleBlock2x(64)
self.upsample = FinalUpsampleBlock2x(64, out_nc=in_nc)
def forward(self, x):
fea = self.initial_conv(x)
@ -69,10 +69,10 @@ class ChainedEmbeddingGen(nn.Module):
class ChainedEmbeddingGenWithStructure(nn.Module):
def __init__(self, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
def __init__(self, in_nc=3, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
super(ChainedEmbeddingGenWithStructure, self).__init__()
self.recurrent = recurrent
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False)
if recurrent:
self.recurrent_nf = recurrent_nf
self.recurrent_stride = recurrent_stride
@ -194,3 +194,74 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
blk_means['block_%i' % (i+1,)] = m
return {'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases),
'blocks_std': blk_stds, 'blocks_mean': blk_means}
class MultifacetedChainedEmbeddingGen(nn.Module):
def __init__(self, depth=10):
super(MultifacetedChainedEmbeddingGen, self).__init__()
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.prog_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)])
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
self.structure_upsample = FinalUpsampleBlock2x(64)
self.grad_extract = ImageGradientNoPadding()
self.upsample = FinalUpsampleBlock2x(64)
self.teco_ref_std = 0
self.prog_ref_std = 0
self.block_residual_means = [0 for _ in range(depth)]
self.block_residual_stds = [0 for _ in range(depth)]
self.bypass_maps = []
def forward(self, x, teco_recurrent=None, prog_recurrent=None):
fea = self.initial_conv(x)
# Integrate recurrence inputs.
if teco_recurrent is not None:
teco_rec = torch.nn.functional.interpolate(teco_recurrent, scale_factor=2, mode='nearest')
teco_rec = self.teco_recurrent_process(teco_rec)
fea, std = self.teco_recurrent_join(fea, teco_rec)
self.teco_ref_std = std.item()
elif prog_recurrent is not None:
prog_rec = self.prog_recurrent_process(prog_recurrent)
prog_rec, std = self.prog_recurrent_join(fea, prog_rec)
self.prog_ref_std = std.item()
emb = checkpoint(self.spine, fea)
grad = fea
self.bypass_maps = []
for i, block in enumerate(self.blocks):
residual, context = checkpoint(block, fea, *emb)
residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
fea = fea + residual
self.bypass_maps.append(bypass_map.detach())
self.block_residual_means[i] = residual.mean().item()
self.block_residual_stds[i] = residual.std().item()
if i < 3:
structure_br = checkpoint(self.structure_joins[i], grad, fea)
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
out = checkpoint(self.upsample, fea)
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
def visual_dbg(self, step, path):
for i, bm in enumerate(self.bypass_maps):
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, step, net_name):
biases = [b.bias.item() for b in self.bypasses]
blk_stds, blk_means = {}, {}
for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
blk_stds['block_%i' % (i+1,)] = s
blk_means['block_%i' % (i+1,)] = m
return {'teco_std': self.teco_ref_std,
'prog_std': self.prog_ref_std,
'bypass_biases': sum(biases) / len(biases),
'blocks_std': blk_stds, 'blocks_mean': blk_means}

View File

@ -488,12 +488,12 @@ class UpconvBlock(nn.Module):
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
class FinalUpsampleBlock2x(nn.Module):
def __init__(self, nf, block=ConvGnLelu):
def __init__(self, nf, block=ConvGnLelu, out_nc=3):
super(FinalUpsampleBlock2x, self).__init__()
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
block(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False))
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
def forward(self, x):
return self.chain(x)

View File

@ -19,7 +19,7 @@ import models.archs.feature_arch as feature_arch
import models.archs.panet.panet as panet
import models.archs.rcan as rcan
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
StructuredChainedEmbeddingGenWithBypass
StructuredChainedEmbeddingGenWithBypass, MultifacetedChainedEmbeddingGen
logger = logging.getLogger('base')
@ -125,18 +125,22 @@ def define_G(opt, net_key='network_G', scale=None):
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
init_temperature=opt_net['temperature'])
elif which_model == 'chained_gen':
netG = ChainedEmbeddingGen(depth=opt_net['depth'])
in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
netG = ChainedEmbeddingGen(depth=opt_net['depth'], in_nc=in_nc)
elif which_model == 'chained_gen_structured':
rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd)
in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, in_nc=in_nc)
elif which_model == 'chained_gen_structured_with_bypass':
rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
bypass_bias = opt_net['bypass_bias'] if 'bypass_bias' in opt_net.keys() else 0
netG = StructuredChainedEmbeddingGenWithBypass(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, bypass_bias=bypass_bias)
elif which_model == 'multifaceted_chained':
netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'])
elif which_model == "flownet2":
from models.flownet2.models import FlowNet2
ld = torch.load(opt_net['load_path'])