diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index c0be7ba9..985d017f 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
             load_path = self.opt['path']['pretrain_model_%s' % (name,)]
             if load_path is not None:
                 logger.info('Loading model for [%s]' % (load_path))
-                self.load_network(load_path, net)
+                self.load_network(load_path, net, self.opt['path']['strict_load'])
 
     def save(self, iter_step):
         for name, net in self.networks.items():
diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index b3d5de54..fbfb3e3f 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from models.archs import SPSR_util as B
 from .RRDBNet_arch import RRDB
-from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock, ConvGnSilu, MultiConvBlock, ReferenceJoinBlock
 from models.archs.SwitchedResidualGenerator_arch import ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock, ProcessingBranchWithStochasticity
 from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
@@ -499,17 +499,16 @@ class SwitchedSpsrWithRef2(nn.Module):
         transformation_filters = nf
         switch_filters = nf
         self.transformation_counts = xforms
-        self.reference_processor = ReferenceImageBranch(transformation_filters)
-        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters)
-        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
-        transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters, transformation_filters, transformation_filters // 8, 3)
-        # For conjoining two input streams.
-        conjoin_multiplex_fn = functools.partial(MultiplexerWithReducer, nf, multiplx_fn)
-        conjoin_pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
-        conjoin_transform_fn = functools.partial(ProcessingBranchWithStochasticity, transformation_filters * 2, transformation_filters, transformation_filters // 8, 4)
+        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, 3,
+                                        2, use_exp2=True)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)
 
         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
@@ -523,28 +522,30 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
 
-        # Grad branch
+        # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                  pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
-        # Upsampling
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
 
-        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, conjoin_multiplex_fn,
-                                                     pre_transform_block=conjoin_pretransform_fn, transform_block=conjoin_transform_fn,
+        # Join branch (grad+fea)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                     pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -552,25 +553,25 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000
 
     def forward(self, x, ref, center_coord):
-        ref = self.reference_processor(ref, center_coord)
-        x = self.model_fea_conv(x)
+        x_grad = self.get_g_nopadding(x)
 
-        x1, a1 = self.sw1((x, ref), True)
-        x2, a2 = self.sw2((x1, ref), True)
+        x = self.model_fea_conv(x)
+        x = self.noise_ref_join(x, torch.randn_like(x))
+        x1, a1 = self.sw1(x, True)
+        x2, a2 = self.sw2(x, True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)
 
-        x_grad = self.get_g_nopadding(x)
         x_grad = self.grad_conv(x_grad)
-        x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
-                                  identity=x_grad, output_attention_weights=True)
+        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, a3 = self.sw_grad(x_grad, True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)
 
-        x_out, a4 = self.conjoin_sw((torch.cat([x_fea, x_grad], dim=1), ref),
-                                    identity=x_fea, output_attention_weights=True)
+        x_out = self.conjoin_ref_join(x_fea, x_grad)
+        x_out, a4 = self.conjoin_sw(x_out, True)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 6ad7f6fb..4c84fc99 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -454,6 +454,21 @@ class ConjoinBlock(nn.Module):
         return self.decimate(x)
 
 
+# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
+class ReferenceJoinBlock(nn.Module):
+    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=True):
+        super(ReferenceJoinBlock, self).__init__()
+        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
+                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     weight_init_factor=residual_weight_init_factor)
+        self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
+
+    def forward(self, x, ref):
+        joined = torch.cat([x, ref], dim=1)
+        branch = self.branch(joined)
+        return self.join_conv(x + branch)
+
+
 # Basic convolutional upsampling block that uses interpolate.
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
diff --git a/codes/train.py b/codes/train.py
index 8009981c..f993f4e9 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_csnln.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
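---
Reviewer note: to poke at the new ReferenceJoinBlock in isolation, a smoke test along these lines should work. This is a sketch, not part of the patch; it assumes codes/ is on PYTHONPATH and the values nf=64, batch 2, and 32x32 spatial size are arbitrary choices for illustration.

    # Sketch: sanity-check the residual join added in arch_util.py (not part of the patch).
    import torch
    from models.archs.arch_util import ReferenceJoinBlock

    nf = 64                                 # arbitrary filter width for this sketch
    join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
    trunk = torch.randn(2, nf, 32, 32)      # mainline trunk features
    ref = torch.randn(2, nf, 32, 32)        # reference features, same shape as the trunk
    out = join(trunk, ref)                  # concat -> MultiConvBlock residual -> join conv
    assert out.shape == trunk.shape         # the join should be shape-preserving

This mirrors how the generator now calls the block, e.g. self.noise_ref_join(x, torch.randn_like(x)) in SwitchedSpsrWithRef2.forward().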