SPSR3 work
SPSR3 is meant to fix whatever is causing the switching units inside the newer SPSR architectures to fail and effectively stop using their multiplexers.
parent 5606e8b0ee
commit e6207d4c50
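For readers unfamiliar with the switched units the description refers to: conceptually, each switch uses a multiplexer to predict per-position attention weights over a bank of transform branches and blends their outputs under a softmax temperature. The sketch below is a deliberately simplified, hypothetical illustration of that idea only; it is not the repository's ConfigurableSwitchComputer or ConvBasisMultiplexer, and every name and shape in it is an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySwitch(nn.Module):
    # Simplified illustration of a switched unit: a multiplexer predicts
    # per-pixel weights over `transform_count` branches, and the switch
    # returns the weighted blend plus the attention map. Not the repo's code.
    def __init__(self, nf, transform_count, init_temp=10.0):
        super().__init__()
        self.transforms = nn.ModuleList(
            [nn.Conv2d(nf, nf, kernel_size=3, padding=1) for _ in range(transform_count)])
        self.multiplexer = nn.Conv2d(nf, transform_count, kernel_size=1)  # selection logits
        self.temperature = init_temp  # typically annealed toward 1 during training

    def forward(self, x):
        attn = F.softmax(self.multiplexer(x) / self.temperature, dim=1)   # (B, T, H, W)
        outs = torch.stack([t(x) for t in self.transforms], dim=1)        # (B, T, C, H, W)
        return (attn.unsqueeze(2) * outs).sum(dim=1), attn

If the attention collapses, for example because the temperature never anneals or the multiplexer receives no useful gradient, the switch effectively ignores its multiplexer and behaves like a single fixed branch, which is the kind of failure the commit description alludes to.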
@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
             load_path = self.opt['path']['pretrain_model_%s' % (name,)]
             if load_path is not None:
                 logger.info('Loading model for [%s]' % (load_path))
-                self.load_network(load_path, net)
+                self.load_network(load_path, net, self.opt['path']['strict_load'])

     def save(self, iter_step):
         for name, net in self.networks.items():
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from models.archs import SPSR_util as B
 from .RRDBNet_arch import RRDB
-from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock, ReferenceJoinBlock
 from models.archs.SwitchedResidualGenerator_arch import ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity
 from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
@@ -499,17 +499,16 @@ class SwitchedSpsrWithRef2(nn.Module):
         transformation_filters = nf
         switch_filters = nf
         self.transformation_counts = xforms
-        self.reference_processor = ReferenceImageBranch(transformation_filters)
-        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters)
-        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
-        transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters, transformation_filters, transformation_filters // 8, 3)
-        # For conjoining two input streams.
-        conjoin_multiplex_fn = functools.partial(MultiplexerWithReducer, nf, multiplx_fn)
-        conjoin_pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
-        conjoin_transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters * 2, transformation_filters, transformation_filters // 8, 4)
+        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, 3,
+                                        2, use_exp2=True)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)

         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
@@ -523,28 +522,30 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

-        # Grad branch
+        # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                  pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
-        # Upsampling
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)

-        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                     pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        # Join branch (grad+fea
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                     pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -552,25 +553,25 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000

     def forward(self, x, ref, center_coord):
-        ref = self.reference_processor(ref, center_coord)
-        x = self.model_fea_conv(x)
+        x_grad = self.get_g_nopadding(x)

-        x1, a1 = self.sw1((x, ref), True)
-        x2, a2 = self.sw2((x1, ref), True)
+        x = self.model_fea_conv(x)
+        x = self.noise_ref_join(x, torch.randn_like(x))
+        x1, a1 = self.sw1(x, True)
+        x2, a2 = self.sw2(x, True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)

-        x_grad = self.get_g_nopadding(x)
         x_grad = self.grad_conv(x_grad)
-        x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
-                                  identity=x_grad, output_attention_weights=True)
+        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, a3 = self.sw_grad(x_grad, True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)

-        x_out, a4 = self.conjoin_sw((torch.cat([x_fea, x_grad], dim=1), ref),
-                                    identity=x_fea, output_attention_weights=True)
+        x_out = self.conjoin_ref_join(x_fea, x_grad)
+        x_out, a4 = self.conjoin_sw(x_out, True)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
@@ -454,6 +454,21 @@ class ConjoinBlock(nn.Module):
         return self.decimate(x)


+# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
+class ReferenceJoinBlock(nn.Module):
+    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=True):
+        super(ReferenceJoinBlock, self).__init__()
+        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
+                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     weight_init_factor=residual_weight_init_factor)
+        self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
+
+    def forward(self, x, ref):
+        joined = torch.cat([x, ref], dim=1)
+        branch = self.branch(joined)
+        return self.join_conv(x + branch)
+
+
 # Basic convolutional upsampling block that uses interpolate.
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
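As a usage illustration assembled from the added lines above: the generator now instantiates ReferenceJoinBlock three times (noise_ref_join, grad_ref_join, conjoin_ref_join) and calls each with a trunk tensor and a same-shaped reference tensor. Below is a minimal standalone sketch, assuming the repository's codes directory is on PYTHONPATH; nf=64 and the 32x32 spatial size are arbitrary values chosen for the example.

import torch
from models.archs.arch_util import ReferenceJoinBlock

nf = 64
trunk = torch.randn(1, nf, 32, 32)   # mainline features, e.g. x_grad in the grad branch
ref = torch.randn(1, nf, 32, 32)     # reference features, e.g. x1 from the feature branch

# Same constructor arguments as grad_ref_join in the diff above.
join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
out = join(trunk, ref)               # join_conv(trunk + MultiConvBlock(cat([trunk, ref], dim=1)))
assert out.shape == trunk.shape      # residual join preserves the trunk's shape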
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_csnln.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)