SPSR3 work

SPSR3 is meant to fix whatever is causing the switching units
inside the newer SPSR architectures to fail and, in effect,
stop using the multiplexers.
James Betker 2020-09-08 15:14:23 -06:00
parent 5606e8b0ee
commit e6207d4c50
4 changed files with 48 additions and 32 deletions
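
For context: a "switch" in these architectures pairs a bank of transform branches with a multiplexer network that predicts per-pixel attention over those branches, and the switch output is the attention-weighted sum of the branch outputs. The sketch below illustrates that pattern with a hypothetical ToySwitch (it is not the repo's ConfigurableSwitchComputer); when the predicted attention degenerates into a near-constant map, the multiplexer stops influencing the output, which is the failure this commit targets.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Minimal sketch of a switched unit (hypothetical; NOT the repo's
    # ConfigurableSwitchComputer). A multiplexer predicts per-pixel attention
    # over K transform branches; the output is the attention-weighted sum.
    class ToySwitch(nn.Module):
        def __init__(self, nf, k=4):
            super().__init__()
            self.transforms = nn.ModuleList([nn.Conv2d(nf, nf, 3, padding=1) for _ in range(k)])
            self.multiplexer = nn.Conv2d(nf, k, 3, padding=1)  # selection logits

        def forward(self, x, temperature=1.0):
            att = F.softmax(self.multiplexer(x) / temperature, dim=1)       # (B, K, H, W)
            branches = torch.stack([t(x) for t in self.transforms], dim=1)  # (B, K, C, H, W)
            # If att collapses to a constant map, this reduces to a fixed blend
            # of the branches and the multiplexer is effectively unused.
            return (att.unsqueeze(2) * branches).sum(dim=1), att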


@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
             load_path = self.opt['path']['pretrain_model_%s' % (name,)]
             if load_path is not None:
                 logger.info('Loading model for [%s]' % (load_path))
-                self.load_network(load_path, net)
+                self.load_network(load_path, net, self.opt['path']['strict_load'])

     def save(self, iter_step):
         for name, net in self.networks.items():
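
The only change above is threading the strict_load option into load_network. In MMSR-descended trainers like this one, that flag presumably ends up as the strict argument of load_state_dict; a hedged sketch of the assumed plumbing (not copied from this repo):

    import torch

    # Assumed shape of load_network (modeled on MMSR/BasicSR conventions;
    # the real implementation in this repo may differ):
    def load_network(load_path, network, strict=True):
        state = torch.load(load_path)
        # strict=False tolerates missing/unexpected keys, e.g. loading an older
        # SPSR checkpoint into an architecture that just gained new blocks.
        network.load_state_dict(state, strict=strict)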


@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from models.archs import SPSR_util as B from models.archs import SPSR_util as B
from .RRDBNet_arch import RRDB from .RRDBNet_arch import RRDB
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock, ReferenceJoinBlock
from models.archs.SwitchedResidualGenerator_arch import ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity from models.archs.SwitchedResidualGenerator_arch import ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity
from switched_conv_util import save_attention_to_image_rgb from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity from switched_conv import compute_attention_specificity
@@ -499,17 +499,16 @@ class SwitchedSpsrWithRef2(nn.Module):
         transformation_filters = nf
         switch_filters = nf
         self.transformation_counts = xforms
-        self.reference_processor = ReferenceImageBranch(transformation_filters)
-        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters)
-        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
-        transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters, transformation_filters, transformation_filters // 8, 3)
-        # For conjoining two input streams.
-        conjoin_multiplex_fn = functools.partial(MultiplexerWithReducer, nf, multiplx_fn)
-        conjoin_pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
-        conjoin_transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters * 2, transformation_filters, transformation_filters // 8, 4)
+        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, 3,
+                                        2, use_exp2=True)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)

         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
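
Note what changed above: the reference-conditioned machinery (ReferencingConvMultiplexer, AdaInConvBlock, ProcessingBranchWithStochasticity) is replaced with the plain ConvBasisMultiplexer path, and both replacement blocks are constructed with weight_init_factor=.1. Assuming the usual convention in this family of archs, that factor scales down the initial conv weights so each transform branch starts near zero and the switch begins close to an identity mapping, roughly:

    import torch
    import torch.nn as nn

    # Assumed semantics of weight_init_factor (scaled kaiming init; check
    # arch_util.py for the authoritative implementation):
    def scaled_kaiming_init_(conv: nn.Conv2d, factor: float = 0.1):
        nn.init.kaiming_normal_(conv.weight, a=0, mode='fan_in')
        with torch.no_grad():
            conv.weight.mul_(factor)   # branch output starts small
            if conv.bias is not None:
                conv.bias.zero_()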
@@ -523,28 +522,30 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

-        # Grad branch
+        # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                  pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
-        # Upsampling
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)

-        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                     pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        # Join branch (grad+fea)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                     pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -552,25 +553,25 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000

     def forward(self, x, ref, center_coord):
-        ref = self.reference_processor(ref, center_coord)
-        x = self.model_fea_conv(x)
-        x1, a1 = self.sw1((x, ref), True)
-        x2, a2 = self.sw2((x1, ref), True)
+        x_grad = self.get_g_nopadding(x)
+
+        x = self.model_fea_conv(x)
+        x = self.noise_ref_join(x, torch.randn_like(x))
+        x1, a1 = self.sw1(x, True)
+        x2, a2 = self.sw2(x, True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)

-        x_grad = self.get_g_nopadding(x)
         x_grad = self.grad_conv(x_grad)
-        x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
-                                  identity=x_grad, output_attention_weights=True)
+        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, a3 = self.sw_grad(x_grad, True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)

-        x_out, a4 = self.conjoin_sw((torch.cat([x_fea, x_grad], dim=1), ref),
-                                    identity=x_fea, output_attention_weights=True)
+        x_out = self.conjoin_ref_join(x_fea, x_grad)
+        x_out, a4 = self.conjoin_sw(x_out, True)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
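
The rewritten forward drops reference conditioning from the switches entirely: the reference processor is gone, noise is mixed into the feature trunk through noise_ref_join, and the grad/fea streams are joined by ReferenceJoinBlocks instead of channel concatenation inside the switches. A hypothetical smoke test (constructor arguments and import path are assumptions inferred from the fields used above, not taken from the source):

    import torch
    from models.archs.SPSR_arch import SwitchedSpsrWithRef2  # path assumed

    # All constructor arguments below are assumptions for illustration.
    net = SwitchedSpsrWithRef2(in_nc=3, out_nc=3, nf=64, xforms=8, upscale=4)
    lr = torch.randn(1, 3, 32, 32)    # low-res input
    ref = torch.randn(1, 3, 32, 32)   # still in the signature, though the lines
                                      # shown no longer condition the switches on it
    outputs = net(lr, ref, center_coord=None)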


@@ -454,6 +454,21 @@ class ConjoinBlock(nn.Module):
         return self.decimate(x)


+# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
+class ReferenceJoinBlock(nn.Module):
+    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=True):
+        super(ReferenceJoinBlock, self).__init__()
+        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
+                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     weight_init_factor=residual_weight_init_factor)
+        self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
+
+    def forward(self, x, ref):
+        joined = torch.cat([x, ref], dim=1)
+        branch = self.branch(joined)
+        return self.join_conv(x + branch)
+
+
 # Basic convolutional upsampling block that uses interpolate.
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
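
In other words, ReferenceJoinBlock computes join_conv(x + branch(cat(x, ref))): the concatenated trunk and reference pass through a MultiConvBlock whose scale_init/weight_init_factor keep it near zero at initialization, so the trunk initially flows through almost untouched and the reference is mixed in gradually. A usage sketch:

    import torch
    from models.archs.arch_util import ReferenceJoinBlock

    # Join a 64-channel trunk with an equally-shaped reference map. With
    # residual_weight_init_factor=.1 the residual branch starts near zero,
    # so the block initially behaves like join_conv(x).
    rj = ReferenceJoinBlock(nf=64, residual_weight_init_factor=.1, norm=False)
    x = torch.randn(2, 64, 32, 32)    # mainline trunk features
    ref = torch.randn(2, 64, 32, 32)  # reference features (same spatial size)
    out = rj(x, ref)                  # -> (2, 64, 32, 32)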


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_csnln.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)