Add updated spsr net for test

This commit is contained in:
James Betker 2020-09-07 17:01:48 -06:00
parent 55475d2ac1
commit a18ece62ee
5 changed files with 153 additions and 13 deletions

View File

@ -4,8 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from models.archs import SPSR_util as B
from .RRDBNet_arch import RRDB
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity
from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
import functools
@ -473,3 +473,133 @@ class SwitchedSpsrWithRef(nn.Module):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
class MultiplexerWithReducer(nn.Module):
def __init__(self, base_filters, multiplx_create_fn, transform_count):
super(MultiplexerWithReducer, self).__init__()
self.proc1 = ConvGnSilu(base_filters*2, base_filters*2, bias=False)
self.proc2 = ConvGnSilu(base_filters*2, base_filters*2, bias=False)
self.reduce = ConvGnSilu(base_filters*2, base_filters, activation=False, norm=False, bias=True)
self.conjoin = ConjoinBlock(base_filters)
self.mplex = multiplx_create_fn(transform_count)
def forward(self, x, ref):
x = self.proc1(x)
x = self.proc2(x)
x = self.reduce(x)
return self.mplex(x, ref)
class SwitchedSpsrWithRef2(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SwitchedSpsrWithRef2, self).__init__()
n_upscale = int(math.log(upscale, 2))
# switch options
transformation_filters = nf
switch_filters = nf
self.transformation_counts = xforms
self.reference_processor = ReferenceImageBranch(transformation_filters)
multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters)
pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters, transformation_filters, transformation_filters // 8, 3)
# For conjoining two input streams.
conjoin_multiplex_fn = functools.partial(MultiplexerWithReducer, nf, multiplx_fn)
conjoin_pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
conjoin_transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters * 2, transformation_filters, transformation_filters // 8, 4)
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
# Grad branch
self.get_g_nopadding = ImageGradientNoPadding()
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
# Upsampling
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
self.attentions = None
self.init_temperature = init_temperature
self.final_temperature_step = 10000
def forward(self, x, ref, center_coord):
x_grad = self.get_g_nopadding(x)
ref = self.reference_processor(ref, center_coord)
x = self.model_fea_conv(x)
x1, a1 = self.sw1((x, ref), True)
x2, a2 = self.sw2((x1, ref), True)
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_lr_conv2(x_fea)
x_grad = self.grad_conv(x_grad)
x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
identity=x_grad, output_attention_weights=True)
x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_lr_conv2(x_grad)
x_grad_out = self.upsample_grad(x_grad)
x_grad_out = self.grad_branch_output_conv(x_grad_out)
x_out, a4 = self.conjoin_sw((torch.cat([x_fea, x_grad], dim=1), ref),
identity=x_fea, output_attention_weights=True)
x_out = self.final_lr_conv(x_out)
x_out = self.upsample(x_out)
x_out = self.final_hr_conv1(x_out)
x_out = self.final_hr_conv2(x_out)
self.attentions = [a1, a2, a3, a4]
return x_grad_out, x_out, x_grad
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp)
if step % 200 == 0:
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
def get_debug_values(self, step):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"switch_temperature": temp}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val

View File

@ -138,6 +138,19 @@ class AdaInConvBlock(nn.Module):
return self.post_fuse_conv(x)
class ProcessingBranchWithStochasticity(nn.Module):
def __init__(self, nf_in, nf_out, noise_filters, depth):
super(ProcessingBranchWithStochasticity, self).__init__()
nf_gap = nf_out - nf_in
self.noise_filters = noise_filters
self.processor = MultiConvBlock(nf_in + noise_filters, nf_in + nf_gap // 2, nf_out, kernel_size=3, depth=depth, weight_init_factor = .1)
def forward(self, x):
b, c, h, w = x.shape
noise = torch.randn((b, self.noise_filters, h, w), device=x.device)
return self.processor(torch.cat([x, noise], dim=1))
# This is similar to ConvBasisMultiplexer, except that it takes a linear reference tensor as a second input to
# provide better results. It also has fixed parameterization in several places
class ReferencingConvMultiplexer(nn.Module):

View File

@ -440,4 +440,4 @@ class UpconvBlock(nn.Module):
def forward(self, x):
x = F.interpolate(x, scale_factor=2, mode="nearest")
return self.process(x)
return self.process(x)

View File

@ -1,20 +1,14 @@
import torch
import logging
from munch import munchify
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.HighToLowResNet as HighToLowResNet
import models.archs.NestedSwitchGenerator as ng
import models.archs.feature_arch as feature_arch
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.SRG1_arch as srg1
import models.archs.ProgressiveSrg_arch as psrg
import models.archs.SPSR_arch as spsr
import models.archs.arch_util as arch_util
import functools
from collections import OrderedDict
logger = logging.getLogger('base')
@ -61,10 +55,13 @@ def define_G(opt, net_key='network_G', scale=None):
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.SwitchedSpsrWithRef(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "spsr_switched_with_ref4x":
elif which_model == "spsr_switched_with_ref2":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.SwitchedSpsrWithRef4x(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms,
netG = spsr.SwitchedSpsrWithRef2(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "csnln":
import model.csnln as csnln
netG = csnln.CSNLN(munchify(opt_net))
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref_gan_no_branch.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_csnln.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)