From c1c9c5681f7a9897d1af0766e1f557f431d7a56c Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 17 Oct 2020 08:40:28 -0600
Subject: [PATCH] Swap recurrence

---
 codes/models/archs/ChainedEmbeddingGen.py | 15 +++++++--------
 codes/models/networks.py                  | 24 ++++++++++++------------
 2 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py
index d304458d..6781630c 100644
--- a/codes/models/archs/ChainedEmbeddingGen.py
+++ b/codes/models/archs/ChainedEmbeddingGen.py
@@ -3,7 +3,7 @@ from torch import nn
 
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
-    FinalUpsampleBlock2x
+    FinalUpsampleBlock2x, ReferenceJoinBlock
 from models.archs.spinenet_arch import SpineNet
 from utils.util import checkpoint
 
@@ -69,10 +69,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
     def __init__(self, depth=10, recurrent=False):
         super(ChainedEmbeddingGenWithStructure, self).__init__()
         self.recurrent = recurrent
+        self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
         if recurrent:
-            self.initial_conv_rec = ConvGnLelu(6, 64, kernel_size=7, bias=True, norm=False, activation=False)
-        else:
-            self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
+            self.recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
+            self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
         self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
         self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
         self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
@@ -83,11 +83,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
 
     def forward(self, x, recurrent=None):
         emb = checkpoint(self.spine, x)
+        fea = self.initial_conv(x)
         if self.recurrent:
-            fea = torch.cat([x,recurrent], dim=1)
-            fea = self.initial_conv_rec(x)
-        else:
-            fea = self.initial_conv(x)
+            rec = self.recurrent_process(recurrent)
+            fea, _ = self.recurrent_join(fea, rec)
         grad = fea
         for i, block in enumerate(self.blocks):
             fea = fea + checkpoint(block, fea, *emb)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 2c232fbc..e2d4aaf6 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -1,24 +1,24 @@
+import functools
+import logging
+from collections import OrderedDict
+
 import munch
 import torch
-import logging
+import torchvision
 from munch import munchify
-import models.archs.SRResNet_arch as SRResNet_arch
-import models.archs.discriminator_vgg_arch as SRGAN_arch
+
 import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
 import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
 import models.archs.RRDBNet_arch as RRDBNet_arch
-import models.archs.feature_arch as feature_arch
-import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SPSR_arch as spsr
+import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.StructuredSwitchedGenerator as ssg
-import models.archs.rcan as rcan
+import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
+import models.archs.discriminator_vgg_arch as SRGAN_arch
+import models.archs.feature_arch as feature_arch
 import models.archs.panet.panet as panet
-from collections import OrderedDict
-import torchvision
-import functools
-
-from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
-    ChainedEmbeddingGenWithStructureR2
+import models.archs.rcan as rcan
+from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure
 
 logger = logging.getLogger('base')
 
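For reference, a minimal usage sketch of the reworked recurrent path (not part of the patch). It assumes the tensor fed back as recurrent is the generator's previous 2x-upsampled output, which would explain why the new recurrent_process conv uses stride=2 to bring it back down to the 64-channel feature resolution; the exact return values of forward() are left unpacked since they are not shown in this diff.

    # Hypothetical usage sketch; the shapes and the recurrent-input assumption
    # are illustrative, not taken from the patch itself.
    import torch
    from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGenWithStructure

    gen = ChainedEmbeddingGenWithStructure(depth=10, recurrent=True)
    lr = torch.randn(1, 3, 64, 64)        # low-resolution input batch
    prev = torch.randn(1, 3, 128, 128)    # previous pass output at 2x resolution
    outputs = gen(lr, recurrent=prev)     # stride-2 recurrent_process reduces prev to match fea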