SPSR3 work

SPSR3 is meant to fix whatever is causing the switching units
inside the newer SPSR architectures to fail and effectively
stop using their multiplexers.
James Betker 2020-09-08 15:14:23 -06:00
parent 5606e8b0ee
commit e6207d4c50
4 changed files with 48 additions and 32 deletions


@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
             load_path = self.opt['path']['pretrain_model_%s' % (name,)]
             if load_path is not None:
                 logger.info('Loading model for [%s]' % (load_path))
-                self.load_network(load_path, net)
+                self.load_network(load_path, net, self.opt['path']['strict_load'])

     def save(self, iter_step):
         for name, net in self.networks.items():
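
The new argument threads the config's strict_load flag through to checkpoint loading. As a rough sketch of what load_network presumably does with it (the body below is an assumption for illustration; only the call site above is part of this commit), it would forward the flag to PyTorch's load_state_dict so non-strict loading can tolerate missing or renamed parameters:

import torch

def load_network(load_path, network, strict=True):
    # Hypothetical helper body: load a checkpoint from disk and let `strict`
    # decide whether every key must match the model's state dict.
    state_dict = torch.load(load_path, map_location='cpu')
    network.load_state_dict(state_dict, strict=strict)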


@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from models.archs import SPSR_util as B
 from .RRDBNet_arch import RRDB
-from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock, ReferenceJoinBlock
 from models.archs.SwitchedResidualGenerator_arch import ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity
 from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
@@ -499,17 +499,16 @@ class SwitchedSpsrWithRef2(nn.Module):
         transformation_filters = nf
         switch_filters = nf
         self.transformation_counts = xforms
-        self.reference_processor = ReferenceImageBranch(transformation_filters)
-        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters)
-        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
-        transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters, transformation_filters, transformation_filters // 8, 3)
-        # For conjoining two input streams.
-        conjoin_multiplex_fn = functools.partial(MultiplexerWithReducer, nf, multiplx_fn)
-        conjoin_pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
-        conjoin_transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters * 2, transformation_filters, transformation_filters // 8, 4)
+        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, 3,
+                                        2, use_exp2=True)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)

         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
@@ -523,28 +522,30 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

-        # Grad branch
+        # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                  pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
         # Upsampling
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)

-        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                     pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        # Join branch (grad+fea)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                     pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -552,25 +553,25 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000

     def forward(self, x, ref, center_coord):
-        ref = self.reference_processor(ref, center_coord)
-        x = self.model_fea_conv(x)
-        x_grad = self.get_g_nopadding(x)
-        x1, a1 = self.sw1((x, ref), True)
-        x2, a2 = self.sw2((x1, ref), True)
+        x = self.model_fea_conv(x)
+        x = self.noise_ref_join(x, torch.randn_like(x))
+        x1, a1 = self.sw1(x, True)
+        x2, a2 = self.sw2(x, True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)

+        x_grad = self.get_g_nopadding(x)
         x_grad = self.grad_conv(x_grad)
-        x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
-                                  identity=x_grad, output_attention_weights=True)
+        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, a3 = self.sw_grad(x_grad, True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)

-        x_out, a4 = self.conjoin_sw((torch.cat([x_fea, x_grad], dim=1), ref),
-                                    identity=x_fea, output_attention_weights=True)
+        x_out = self.conjoin_ref_join(x_fea, x_grad)
+        x_out, a4 = self.conjoin_sw(x_out, True)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
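
In short, the forward pass no longer feeds the reference tensor into the switches themselves: each branch first folds its companion stream in through a ReferenceJoinBlock, and the switch is then called with just the trunk tensor plus a flag requesting attention weights. Condensed from the gradient branch above:

# Before: reference and companion stream packed into the switch call.
x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
                          identity=x_grad, output_attention_weights=True)
# After: residual join first, then a plain switch call that also returns attention.
x_grad = self.grad_ref_join(x_grad, x1)
x_grad, a3 = self.sw_grad(x_grad, True)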


@@ -454,6 +454,21 @@ class ConjoinBlock(nn.Module):
         return self.decimate(x)

+
+# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
+class ReferenceJoinBlock(nn.Module):
+    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=True):
+        super(ReferenceJoinBlock, self).__init__()
+        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
+                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     weight_init_factor=residual_weight_init_factor)
+        self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
+
+    def forward(self, x, ref):
+        joined = torch.cat([x, ref], dim=1)
+        branch = self.branch(joined)
+        return self.join_conv(x + branch)
+

 # Basic convolutional upsampling block that uses interpolate.
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
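
ReferenceJoinBlock is the main new building block in this commit, so here is a minimal usage sketch, assuming it is imported from models.archs.arch_util as in the SPSR arch change above; shapes are illustrative, and both inputs must share the same channel count nf and spatial size:

import torch
from models.archs.arch_util import ReferenceJoinBlock

join = ReferenceJoinBlock(nf=64, residual_weight_init_factor=.1, norm=False)
trunk = torch.randn(1, 64, 32, 32)       # mainline features
reference = torch.randn(1, 64, 32, 32)   # reference features to fold in
out = join(trunk, reference)             # residual merge; out keeps shape [1, 64, 32, 32]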


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_csnln.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)