forked from mrq/DL-Art-School
Fixes to the spsr3
Some lessons learned: - Biases are fairly important as a relief valve. They dont need to be everywhere, but most computationally heavy branches should have a bias. - GroupNorm in SPSR is not a great idea. Since image gradients are represented in this model, normal means and standard deviations are not applicable. (imggrad has a high representation of 0). - Don't fuck with the mainline of any generative model. As much as possible, all additions should be done through residual connections. Never pollute the mainline with reference data, do that in branches. It basically leaves the mode untrainable.
This commit is contained in:
parent
0ffac391c1
commit
747ded2bf7
|
@ -545,27 +545,29 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.ref_join3 = RefJoiner(nf)
|
||||
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
|
||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
|
||||
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
|
||||
|
||||
# Join branch (grad+fea)
|
||||
self.ref_join4 = RefJoiner(nf)
|
||||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
|
||||
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
|
||||
self.attentions = None
|
||||
|
@ -577,27 +579,26 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
ref = self.reference_processor(ref, center_coord)
|
||||
|
||||
x = self.model_fea_conv(x)
|
||||
x = self.noise_ref_join(x, torch.randn_like(x))
|
||||
x = self.ref_join1(x, ref)
|
||||
x1, a1 = self.sw1(x, True)
|
||||
x1 = self.noise_ref_join(x, torch.randn_like(x))
|
||||
x1 = self.ref_join1(x1, ref)
|
||||
x1, a1 = self.sw1(x1, True, identity=x)
|
||||
|
||||
x2 = x1
|
||||
x2 = self.ref_join2(x2, ref)
|
||||
x2, a2 = self.sw2(x2, True)
|
||||
x_fea = self.feature_lr_conv(x2)
|
||||
x_fea = self.feature_lr_conv2(x_fea)
|
||||
x2, a2 = self.sw2(x2, True, identity=x1)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = self.grad_conv(x_grad)
|
||||
x_grad = self.ref_join3(x_grad_identity, ref)
|
||||
x_grad = self.grad_ref_join(x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
x_grad = self.grad_lr_conv2(x_grad)
|
||||
x_grad_out = self.upsample_grad(x_grad)
|
||||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||
|
||||
x_out = x_fea
|
||||
x_out = self.ref_join4(x2, ref)
|
||||
x_out = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
x_out = self.upsample(x_out)
|
||||
x_out = self.final_hr_conv1(x_out)
|
||||
|
|
|
@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
|
|||
import torch.nn.functional as F
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock
|
||||
from switched_conv_util import save_attention_to_image_rgb
|
||||
import os
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -119,6 +119,7 @@ def main():
|
|||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
#### create train and val dataloader
|
||||
dataset_ratio = 1 # enlarge the size of each epoch
|
||||
|
@ -173,7 +174,7 @@ def main():
|
|||
|
||||
_t = time()
|
||||
_profile = False
|
||||
for _, train_data in enumerate(tq_ldr):
|
||||
for train_data in tq_ldr:
|
||||
if _profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
|
Loading…
Reference in New Issue
Block a user