Swap recurrence
This commit is contained in:
parent
6141aa1110
commit
c1c9c5681f
|
@ -3,7 +3,7 @@ from torch import nn
|
||||||
|
|
||||||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||||
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
|
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
|
||||||
FinalUpsampleBlock2x
|
FinalUpsampleBlock2x, ReferenceJoinBlock
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.archs.spinenet_arch import SpineNet
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
@ -69,10 +69,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
def __init__(self, depth=10, recurrent=False):
|
def __init__(self, depth=10, recurrent=False):
|
||||||
super(ChainedEmbeddingGenWithStructure, self).__init__()
|
super(ChainedEmbeddingGenWithStructure, self).__init__()
|
||||||
self.recurrent = recurrent
|
self.recurrent = recurrent
|
||||||
if recurrent:
|
|
||||||
self.initial_conv_rec = ConvGnLelu(6, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
|
||||||
else:
|
|
||||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
|
if recurrent:
|
||||||
|
self.recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
|
||||||
|
self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
||||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
||||||
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
||||||
|
@ -83,11 +83,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, recurrent=None):
|
def forward(self, x, recurrent=None):
|
||||||
emb = checkpoint(self.spine, x)
|
emb = checkpoint(self.spine, x)
|
||||||
if self.recurrent:
|
|
||||||
fea = torch.cat([x,recurrent], dim=1)
|
|
||||||
fea = self.initial_conv_rec(x)
|
|
||||||
else:
|
|
||||||
fea = self.initial_conv(x)
|
fea = self.initial_conv(x)
|
||||||
|
if self.recurrent:
|
||||||
|
rec = self.recurrent_process(recurrent)
|
||||||
|
fea, _ = self.recurrent_join(fea, rec)
|
||||||
grad = fea
|
grad = fea
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
fea = fea + checkpoint(block, fea, *emb)
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
|
|
|
@ -1,24 +1,24 @@
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import munch
|
import munch
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import torchvision
|
||||||
from munch import munchify
|
from munch import munchify
|
||||||
import models.archs.SRResNet_arch as SRResNet_arch
|
|
||||||
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
|
||||||
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||||
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
|
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
|
||||||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||||
import models.archs.feature_arch as feature_arch
|
|
||||||
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
|
||||||
import models.archs.SPSR_arch as spsr
|
import models.archs.SPSR_arch as spsr
|
||||||
|
import models.archs.SRResNet_arch as SRResNet_arch
|
||||||
import models.archs.StructuredSwitchedGenerator as ssg
|
import models.archs.StructuredSwitchedGenerator as ssg
|
||||||
import models.archs.rcan as rcan
|
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||||
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
||||||
|
import models.archs.feature_arch as feature_arch
|
||||||
import models.archs.panet.panet as panet
|
import models.archs.panet.panet as panet
|
||||||
from collections import OrderedDict
|
import models.archs.rcan as rcan
|
||||||
import torchvision
|
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure
|
||||||
import functools
|
|
||||||
|
|
||||||
from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
|
|
||||||
ChainedEmbeddingGenWithStructureR2
|
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user