SPSR3 work
SPSR3 is meant to fix whatever is causing the switching units in the newer SPSR architectures to fail and effectively bypass the multiplexers.
parent 5606e8b0ee
commit e6207d4c50
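At a high level, the diff below swaps the reference-aware multiplexers (ReferencingConvMultiplexer / AdaInConvBlock / ProcessingBranchWithStochasticity) for the plain ConvBasisMultiplexer switch stack, and feeds reference and gradient information into the trunk through small residual joins (ReferenceJoinBlock) instead of passing (input, ref) tuples into the switches. The following is a minimal, self-contained sketch of that residual-join pattern, assuming plain Conv2d stacks in place of the repo's MultiConvBlock/ConvGnLelu; the ResidualJoinSketch class is a hypothetical illustration, not code from this commit.

# Sketch of the residual join pattern used by ReferenceJoinBlock in the diff below.
# Assumptions: plain Conv2d layers stand in for MultiConvBlock/ConvGnLelu.
import torch
import torch.nn as nn


class ResidualJoinSketch(nn.Module):
    def __init__(self, nf, residual_weight_init_factor=.1):
        super().__init__()
        # Branch that digests the concatenated (trunk, ref) pair back down to nf channels.
        self.branch = nn.Sequential(
            nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
        )
        # Keep the residual contribution small so the trunk dominates early in training.
        for m in self.branch.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data *= residual_weight_init_factor
        self.join_conv = nn.Conv2d(nf, nf, kernel_size=3, padding=1)

    def forward(self, x, ref):
        joined = torch.cat([x, ref], dim=1)
        return self.join_conv(x + self.branch(joined))


if __name__ == '__main__':
    nf = 64
    join = ResidualJoinSketch(nf)
    trunk = torch.randn(1, nf, 32, 32)
    # In the diff, the first join's "reference" is just noise:
    # x = self.noise_ref_join(x, torch.randn_like(x))
    out = join(trunk, torch.randn_like(trunk))
    print(out.shape)  # torch.Size([1, 64, 32, 32])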
@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
            load_path = self.opt['path']['pretrain_model_%s' % (name,)]
            if load_path is not None:
                logger.info('Loading model for [%s]' % (load_path))
-               self.load_network(load_path, net)
+               self.load_network(load_path, net, self.opt['path']['strict_load'])

    def save(self, iter_step):
        for name, net in self.networks.items():
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from models.archs import SPSR_util as B
 from .RRDBNet_arch import RRDB
-from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock, ReferenceJoinBlock
 from models.archs.SwitchedResidualGenerator_arch import ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity
 from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
@@ -499,17 +499,16 @@ class SwitchedSpsrWithRef2(nn.Module):
         transformation_filters = nf
         switch_filters = nf
         self.transformation_counts = xforms
-        self.reference_processor = ReferenceImageBranch(transformation_filters)
-        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters)
-        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
-        transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters, transformation_filters, transformation_filters // 8, 3)
-        # For conjoining two input streams.
-        conjoin_multiplex_fn = functools.partial(MultiplexerWithReducer, nf, multiplx_fn)
-        conjoin_pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
-        conjoin_transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters * 2, transformation_filters, transformation_filters // 8, 4)
+        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, 3,
+                                        2, use_exp2=True)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)

         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
@@ -523,28 +522,30 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

-        # Grad branch
+        # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                  pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
-        # Upsampling
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)

-        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                     pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        # Join branch (grad+fea)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                     pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -552,25 +553,25 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000

     def forward(self, x, ref, center_coord):
         ref = self.reference_processor(ref, center_coord)
-        x = self.model_fea_conv(x)
-        x_grad = self.get_g_nopadding(x)
-
-        x1, a1 = self.sw1((x, ref), True)
-        x2, a2 = self.sw2((x1, ref), True)
+        x = self.model_fea_conv(x)
+        x = self.noise_ref_join(x, torch.randn_like(x))
+        x1, a1 = self.sw1(x, True)
+        x2, a2 = self.sw2(x, True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)

+        x_grad = self.get_g_nopadding(x)
         x_grad = self.grad_conv(x_grad)
-        x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
-                                  identity=x_grad, output_attention_weights=True)
+        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, a3 = self.sw_grad(x_grad, True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)

-        x_out, a4 = self.conjoin_sw((torch.cat([x_fea, x_grad], dim=1), ref),
-                                    identity=x_fea, output_attention_weights=True)
+        x_out = self.conjoin_ref_join(x_fea, x_grad)
+        x_out, a4 = self.conjoin_sw(x_out, True)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
@@ -454,6 +454,21 @@ class ConjoinBlock(nn.Module):
         return self.decimate(x)


+# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
+class ReferenceJoinBlock(nn.Module):
+    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=True):
+        super(ReferenceJoinBlock, self).__init__()
+        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
+                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     weight_init_factor=residual_weight_init_factor)
+        self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
+
+    def forward(self, x, ref):
+        joined = torch.cat([x, ref], dim=1)
+        branch = self.branch(joined)
+        return self.join_conv(x + branch)
+
+
 # Basic convolutional upsampling block that uses interpolate.
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_csnln.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)