Add ChainedGenWithStructure
This commit is contained in:
parent
96f1be30ed
commit
d856378b2e
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||||
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
|
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
|
||||||
FinalUpsampleBlock2x
|
FinalUpsampleBlock2x
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.archs.spinenet_arch import SpineNet
|
||||||
|
@ -49,11 +50,11 @@ class BasicEmbeddingPyramid(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ChainedEmbeddingGen(nn.Module):
|
class ChainedEmbeddingGen(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, depth=10):
|
||||||
super(ChainedEmbeddingGen, self).__init__()
|
super(ChainedEmbeddingGen, self).__init__()
|
||||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(5)])
|
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
||||||
self.upsample = FinalUpsampleBlock2x(64)
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -62,3 +63,56 @@ class ChainedEmbeddingGen(nn.Module):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
fea = fea + checkpoint(block, fea, *emb)
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
return checkpoint(self.upsample, fea),
|
return checkpoint(self.upsample, fea),
|
||||||
|
|
||||||
|
|
||||||
|
class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
|
def __init__(self, depth=10):
|
||||||
|
super(ChainedEmbeddingGenWithStructure, self).__init__()
|
||||||
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
|
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
||||||
|
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
||||||
|
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
|
||||||
|
self.structure_upsample = FinalUpsampleBlock2x(64)
|
||||||
|
self.grad_extract = ImageGradientNoPadding()
|
||||||
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
emb = checkpoint(self.spine, x)
|
||||||
|
fea = self.initial_conv(x)
|
||||||
|
grad = fea
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
|
if i < 3:
|
||||||
|
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
||||||
|
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
||||||
|
out = checkpoint(self.upsample, fea)
|
||||||
|
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
|
||||||
|
|
||||||
|
|
||||||
|
class ChainedEmbeddingGenWithStructureR2(nn.Module):
|
||||||
|
def __init__(self, depth=10):
|
||||||
|
super(ChainedEmbeddingGenWithStructureR2, self).__init__()
|
||||||
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
|
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
||||||
|
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
||||||
|
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
|
||||||
|
self.structure_upsample = FinalUpsampleBlock2x(64)
|
||||||
|
self.structure_rejoin = ConjoinBlock(64)
|
||||||
|
self.grad_extract = ImageGradientNoPadding()
|
||||||
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
emb = checkpoint(self.spine, x)
|
||||||
|
fea = self.initial_conv(x)
|
||||||
|
grad = fea
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
|
if i < 3:
|
||||||
|
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
||||||
|
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
||||||
|
if i == 3:
|
||||||
|
fea = fea + self.structure_rejoin(fea, grad)
|
||||||
|
out = checkpoint(self.upsample, fea)
|
||||||
|
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
|
||||||
|
|
|
@ -17,7 +17,8 @@ from collections import OrderedDict
|
||||||
import torchvision
|
import torchvision
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen
|
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
|
||||||
|
ChainedEmbeddingGenWithStructureR2
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
@ -123,7 +124,11 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
|
netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
|
||||||
init_temperature=opt_net['temperature'])
|
init_temperature=opt_net['temperature'])
|
||||||
elif which_model == 'chained_gen':
|
elif which_model == 'chained_gen':
|
||||||
netG = ChainedEmbeddingGen()
|
netG = ChainedEmbeddingGen(depth=opt_net['depth'])
|
||||||
|
elif which_model == 'chained_gen_structured':
|
||||||
|
netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'])
|
||||||
|
elif which_model == 'chained_gen_structuredr2':
|
||||||
|
netG = ChainedEmbeddingGenWithStructureR2(depth=opt_net['depth'])
|
||||||
elif which_model == "flownet2":
|
elif which_model == "flownet2":
|
||||||
from models.flownet2.models import FlowNet2
|
from models.flownet2.models import FlowNet2
|
||||||
ld = torch.load(opt_net['load_path'])
|
ld = torch.load(opt_net['load_path'])
|
||||||
|
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgsimpler.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structuredr2.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgdeep.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user