Fixes to the spsr3

Some lessons learned:
- Biases are fairly important as a relief valve. They dont need to be everywhere, but
  most computationally heavy branches should have a bias.
- GroupNorm in SPSR is not a great idea. Since image gradients are represented
   in this model, normal means and standard deviations are not applicable. (imggrad
   has a high representation of 0).
- Don't fuck with the mainline of any generative model. As much as possible, all
   additions should be done through residual connections. Never pollute the mainline
   with reference data, do that in branches. It basically leaves the mode untrainable.
This commit is contained in:
James Betker 2020-09-09 15:28:14 -06:00
parent 0ffac391c1
commit 747ded2bf7
3 changed files with 20 additions and 18 deletions

View File

@ -545,27 +545,29 @@ class SwitchedSpsrWithRef2(nn.Module):
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
self.get_g_nopadding = ImageGradientNoPadding()
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
self.ref_join3 = RefJoiner(nf)
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
# Join branch (grad+fea)
self.ref_join4 = RefJoiner(nf)
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
self.attentions = None
@ -577,27 +579,26 @@ class SwitchedSpsrWithRef2(nn.Module):
ref = self.reference_processor(ref, center_coord)
x = self.model_fea_conv(x)
x = self.noise_ref_join(x, torch.randn_like(x))
x = self.ref_join1(x, ref)
x1, a1 = self.sw1(x, True)
x1 = self.noise_ref_join(x, torch.randn_like(x))
x1 = self.ref_join1(x1, ref)
x1, a1 = self.sw1(x1, True, identity=x)
x2 = x1
x2 = self.ref_join2(x2, ref)
x2, a2 = self.sw2(x2, True)
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_lr_conv2(x_fea)
x2, a2 = self.sw2(x2, True, identity=x1)
x_grad = self.grad_conv(x_grad)
x_grad_identity = self.grad_conv(x_grad)
x_grad = self.ref_join3(x_grad_identity, ref)
x_grad = self.grad_ref_join(x_grad, x1)
x_grad, a3 = self.sw_grad(x_grad, True)
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_lr_conv2(x_grad)
x_grad_out = self.upsample_grad(x_grad)
x_grad_out = self.grad_branch_output_conv(x_grad_out)
x_out = x_fea
x_out = self.ref_join4(x2, ref)
x_out = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, True)
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
x_out = self.final_lr_conv(x_out)
x_out = self.upsample(x_out)
x_out = self.final_hr_conv1(x_out)

View File

@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
import torch.nn.functional as F
import functools
from collections import OrderedDict
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock
from switched_conv_util import save_attention_to_image_rgb
import os
from torch.utils.checkpoint import checkpoint

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@ -119,6 +119,7 @@ def main():
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# torch.autograd.set_detect_anomaly(True)
#### create train and val dataloader
dataset_ratio = 1 # enlarge the size of each epoch
@ -173,7 +174,7 @@ def main():
_t = time()
_profile = False
for _, train_data in enumerate(tq_ldr):
for train_data in tq_ldr:
if _profile:
print("Data fetch: %f" % (time() - _t))
_t = time()