Swap recurrence

James Betker 2020-10-17 08:40:28 -06:00
parent 6141aa1110
commit c1c9c5681f
2 changed files with 19 additions and 20 deletions

View File

@@ -3,7 +3,7 @@ from torch import nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
-    FinalUpsampleBlock2x
+    FinalUpsampleBlock2x, ReferenceJoinBlock
 from models.archs.spinenet_arch import SpineNet
 from utils.util import checkpoint
@@ -69,10 +69,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
     def __init__(self, depth=10, recurrent=False):
         super(ChainedEmbeddingGenWithStructure, self).__init__()
         self.recurrent = recurrent
-        if recurrent:
-            self.initial_conv_rec = ConvGnLelu(6, 64, kernel_size=7, bias=True, norm=False, activation=False)
-        else:
-            self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
+        self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
+        if recurrent:
+            self.recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
+            self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
         self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
         self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
         self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
@@ -83,11 +83,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
     def forward(self, x, recurrent=None):
         emb = checkpoint(self.spine, x)
-        if self.recurrent:
-            fea = torch.cat([x,recurrent], dim=1)
-            fea = self.initial_conv_rec(x)
-        else:
-            fea = self.initial_conv(x)
+        fea = self.initial_conv(x)
+        if self.recurrent:
+            rec = self.recurrent_process(recurrent)
+            fea, _ = self.recurrent_join(fea, rec)
         grad = fea
         for i, block in enumerate(self.blocks):
             fea = fea + checkpoint(block, fea, *emb)
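
In short, the old recurrent path concatenated the fed-back frame onto the input and widened the first conv to 6 channels; the new path keeps the 3-channel initial conv, reduces the recurrent input with a stride-2 conv (which suggests the fed-back frame arrives at the 2x-upsampled output resolution), and merges it in feature space through ReferenceJoinBlock with a near-zero initial residual weight. A minimal runnable sketch of that join pattern, using plain Conv2d stand-ins for ConvGnLelu and a stub for ReferenceJoinBlock (only its call signature is visible above; the stub's internals are assumptions):

import torch
from torch import nn

class JoinSketch(nn.Module):
    """Stand-in for ReferenceJoinBlock: folds a processed reference into the
    main features through a residual branch whose weight starts near zero
    (cf. residual_weight_init_factor=.01), so the join begins near-identity."""
    def __init__(self, channels, residual_weight_init_factor=.01):
        super().__init__()
        self.process = nn.Conv2d(channels * 2, channels, kernel_size=1)
        self.residual_weight = nn.Parameter(torch.tensor(residual_weight_init_factor))

    def forward(self, fea, ref):
        res = self.process(torch.cat([fea, ref], dim=1))
        # Return a tuple to mirror `fea, _ = self.recurrent_join(fea, rec)` above.
        return fea + self.residual_weight * res, res

initial_conv = nn.Conv2d(3, 64, kernel_size=7, padding=3)
recurrent_process = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
join = JoinSketch(64)

x = torch.randn(1, 3, 32, 32)          # current LR input
recurrent = torch.randn(1, 3, 64, 64)  # previous 2x output fed back in

fea = initial_conv(x)                  # (1, 64, 32, 32)
rec = recurrent_process(recurrent)     # stride 2 brings 64x64 back to 32x32
fea, _ = join(fea, rec)

Compared with channel-wise concatenation at the input, joining in feature space with a tiny initial residual weight leaves the pretrained non-recurrent behavior almost untouched at the start of training and lets the network learn how much of the previous frame to use.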

View File

@@ -1,24 +1,24 @@
+import functools
+import logging
+from collections import OrderedDict
 import munch
 import torch
-import logging
+import torchvision
 from munch import munchify
-import models.archs.SRResNet_arch as SRResNet_arch
-import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
 import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
 import models.archs.RRDBNet_arch as RRDBNet_arch
-import models.archs.feature_arch as feature_arch
-import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SPSR_arch as spsr
+import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.StructuredSwitchedGenerator as ssg
-import models.archs.rcan as rcan
+import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
+import models.archs.discriminator_vgg_arch as SRGAN_arch
+import models.archs.feature_arch as feature_arch
 import models.archs.panet.panet as panet
-from collections import OrderedDict
-import torchvision
-import functools
-from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
-    ChainedEmbeddingGenWithStructureR2
+import models.archs.rcan as rcan
+from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure
 logger = logging.getLogger('base')
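
Beyond sorting the imports, this hunk drops ChainedEmbeddingGenWithStructureR2, consistent with the recurrent=True flag on ChainedEmbeddingGenWithStructure replacing the separate R2 class. For reference, the recurrent variant would be driven frame-by-frame, feeding each output back as `recurrent`; a hypothetical inference loop (the zero-initialized first frame, the 2x recurrent resolution, and the exact return value of forward are assumptions, since the hunk above cuts off before the return):

import torch
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGenWithStructure

net = ChainedEmbeddingGenWithStructure(depth=10, recurrent=True).eval()
frames = [torch.randn(1, 3, 32, 32) for _ in range(4)]  # toy LR sequence
recurrent = torch.zeros(1, 3, 64, 64)                   # assumed: zeros at t=0, at 2x resolution

with torch.no_grad():
    for lr in frames:
        out = net(lr, recurrent=recurrent)
        # forward's return isn't shown in the hunk; assume the upsampled
        # image is the first element if a tuple is returned.
        recurrent = out[0] if isinstance(out, (tuple, list)) else out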