From 747ded2bf73106f10bc982c58ec9d243fb837c5f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 9 Sep 2020 15:28:14 -0600
Subject: [PATCH] Fixes to the spsr3

Some lessons learned:
- Biases are fairly important as a relief valve. They don't need to be
  everywhere, but most computationally heavy branches should have a bias.
- GroupNorm in SPSR is not a great idea. Since image gradients are
  represented in this model, normal means and standard deviations are not
  applicable (imggrad has a high representation of 0).
- Don't fuck with the mainline of any generative model. As much as
  possible, all additions should be done through residual connections.
  Never pollute the mainline with reference data; do that in branches.
  Doing so basically leaves the model untrainable.
---
 codes/models/archs/SPSR_arch.py               | 31 ++++++++++---------
 .../archs/SwitchedResidualGenerator_arch.py   |  2 +-
 codes/train.py                                |  5 +--
 3 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index 307d8078..93daa7f5 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -545,27 +545,29 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
+        self.ref_join3 = RefJoiner(nf)
         self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
         self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
 
         # Join branch (grad+fea)
+        self.ref_join4 = RefJoiner(nf)
         self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                      pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -577,27 +579,26 @@ class SwitchedSpsrWithRef2(nn.Module):
         ref = self.reference_processor(ref, center_coord)
         x = self.model_fea_conv(x)
-        x = self.noise_ref_join(x, torch.randn_like(x))
-        x = self.ref_join1(x, ref)
-        x1, a1 = self.sw1(x, True)
+        x1 = self.noise_ref_join(x, torch.randn_like(x))
+        x1 = self.ref_join1(x1, ref)
+        x1, a1 = self.sw1(x1, True, identity=x)
 
         x2 = x1
         x2 = self.ref_join2(x2, ref)
-        x2, a2 = self.sw2(x2, True)
-        x_fea = self.feature_lr_conv(x2)
-        x_fea = self.feature_lr_conv2(x_fea)
+        x2, a2 = self.sw2(x2, True, identity=x1)
 
-        x_grad = self.grad_conv(x_grad)
+        x_grad_identity = self.grad_conv(x_grad)
+        x_grad = self.ref_join3(x_grad_identity, ref)
         x_grad = self.grad_ref_join(x_grad, x1)
-        x_grad, a3 = self.sw_grad(x_grad, True)
+        x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)
 
-        x_out = x_fea
+        x_out = self.ref_join4(x2, ref)
         x_out = self.conjoin_ref_join(x_out, x_grad)
-        x_out, a4 = self.conjoin_sw(x_out, True)
+        x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
 
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index e6f07544..ed82d015 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock
 from switched_conv_util import save_attention_to_image_rgb
 import os
 from torch.utils.checkpoint import checkpoint
diff --git a/codes/train.py b/codes/train.py
index 661e9b92..e7e8c25c 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -119,6 +119,7 @@ def main():
 
     torch.backends.cudnn.benchmark = True
     # torch.backends.cudnn.deterministic = True
+    # torch.autograd.set_detect_anomaly(True)
 
     #### create train and val dataloader
     dataset_ratio = 1  # enlarge the size of each epoch
@@ -173,7 +174,7 @@ def main():
 
             _t = time()
             _profile = False
-            for _, train_data in enumerate(tq_ldr):
+            for train_data in tq_ldr:
                 if _profile:
                     print("Data fetch: %f" % (time() - _t))
                     _t = time()
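
As an illustration of the last lesson in the commit message (feed reference data in through down-weighted residual branches, never on the mainline), here is a minimal, hypothetical PyTorch sketch of that pattern. The class name ResidualRefJoin, its layer choices, and the 0.2 initial weight are illustrative stand-ins, not the repo's actual ReferenceJoinBlock/RefJoiner implementations:

import torch
import torch.nn as nn

class ResidualRefJoin(nn.Module):
    """Hypothetical sketch: reference features enter only as a small,
    learnable residual correction, so the mainline activations stay
    untouched and remain the dominant training signal."""
    def __init__(self, nf, residual_weight_init=0.2):
        super().__init__()
        # Branch that mixes mainline and reference features. No normalization,
        # per the GroupNorm lesson (gradient-like features are zero-heavy).
        self.join = nn.Sequential(
            nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=True),
        )
        # Start the branch's contribution small so early training is driven
        # by the mainline rather than the reference signal.
        self.residual_weight = nn.Parameter(torch.tensor(residual_weight_init))

    def forward(self, x, ref):
        correction = self.join(torch.cat([x, ref], dim=1))
        return x + self.residual_weight * correction

# Usage sketch: `x` passes through unchanged except for the scaled residual,
# mirroring how the patched forward() keeps a clean identity tensor for each
# switch (e.g. sw1(..., identity=x)).
if __name__ == "__main__":
    joiner = ResidualRefJoin(nf=64)
    x = torch.randn(1, 64, 32, 32)
    ref = torch.randn(1, 64, 32, 32)
    print(joiner(x, ref).shape)  # torch.Size([1, 64, 32, 32])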