diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index 307d8078..93daa7f5 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -545,27 +545,29 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
+        self.ref_join3 = RefJoiner(nf)
         self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=False)
-        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
-        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
         self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)

         # Join branch (grad+fea)
+        self.ref_join4 = RefJoiner(nf)
         self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                      pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
                                                      add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
@@ -577,27 +579,26 @@ class SwitchedSpsrWithRef2(nn.Module):
         ref = self.reference_processor(ref, center_coord)

         x = self.model_fea_conv(x)
-        x = self.noise_ref_join(x, torch.randn_like(x))
-        x = self.ref_join1(x, ref)
-        x1, a1 = self.sw1(x, True)
+        x1 = self.noise_ref_join(x, torch.randn_like(x))
+        x1 = self.ref_join1(x1, ref)
+        x1, a1 = self.sw1(x1, True, identity=x)

         x2 = x1
         x2 = self.ref_join2(x2, ref)
-        x2, a2 = self.sw2(x2, True)
-        x_fea = self.feature_lr_conv(x2)
-        x_fea = self.feature_lr_conv2(x_fea)
+        x2, a2 = self.sw2(x2, True, identity=x1)

-        x_grad = self.grad_conv(x_grad)
+        x_grad_identity = self.grad_conv(x_grad)
+        x_grad = self.ref_join3(x_grad_identity, ref)
         x_grad = self.grad_ref_join(x_grad, x1)
-        x_grad, a3 = self.sw_grad(x_grad, True)
+        x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)

-        x_out = x_fea
+        x_out = self.ref_join4(x2, ref)
         x_out = self.conjoin_ref_join(x_out, x_grad)
-        x_out, a4 = self.conjoin_sw(x_out, True)
+        x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index e6f07544..ed82d015 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock
 from switched_conv_util import save_attention_to_image_rgb
 import os
 from torch.utils.checkpoint import checkpoint
diff --git a/codes/train.py b/codes/train.py
index 661e9b92..e7e8c25c 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -119,6 +119,7 @@ def main():

     torch.backends.cudnn.benchmark = True
     # torch.backends.cudnn.deterministic = True
+    # torch.autograd.set_detect_anomaly(True)

     #### create train and val dataloader
     dataset_ratio = 1  # enlarge the size of each epoch
@@ -173,7 +174,7 @@ def main():

         _t = time()
         _profile = False
-        for _, train_data in enumerate(tq_ldr):
+        for train_data in tq_ldr:
             if _profile:
                 print("Data fetch: %f" % (time() - _t))
                 _t = time()
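The thread running through the SPSR_arch.py changes is that every ConfigurableSwitchComputer call site now passes an explicit `identity=` tensor, so the switch's residual connection is anchored to the features from before the reference/noise joins rather than to the joined tensor it transforms. A minimal sketch of that calling convention, assuming a simplified switch whose transform stack is a single conv (the real module in SwitchedResidualGenerator_arch.py selects among several transforms with a learned attention switch; `ResidualSwitchSketch` and its internals are illustrative, not the actual implementation):

```python
import torch
import torch.nn as nn

class ResidualSwitchSketch(nn.Module):
    """Simplified stand-in for ConfigurableSwitchComputer's residual path.

    The real module routes x through several transforms weighted by a
    learned attention switch; a single conv stands in for that here.
    """
    def __init__(self, nf, residual_scale=0.2):
        super().__init__()
        self.transform = nn.Conv2d(nf, nf, kernel_size=3, padding=1)
        # Keep the transform's early contribution small relative to identity.
        self.residual_scale = residual_scale

    def forward(self, x, output_attention=False, identity=None):
        if identity is None:
            identity = x  # old behavior: residual base is the switch input itself
        out = identity + self.residual_scale * self.transform(x)
        if output_attention:
            # Placeholder for the attention map the real switch also returns.
            attention = torch.zeros(x.shape[0], 1, *x.shape[2:])
            return out, attention
        return out

# Mirrors `x1, a1 = self.sw1(x1, True, identity=x)`: the joins only shape what
# the switch sees, while the residual base stays at the pre-join features.
sw = ResidualSwitchSketch(nf=64)
x = torch.randn(1, 64, 32, 32)
x1 = x + 0.1 * torch.randn_like(x)  # stand-in for the noise/reference joins
out, attn = sw(x1, True, identity=x)
```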
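A second pattern in the patch: wherever `norm=True` is dropped from a ConvGnLelu (`final_lr_conv`, `upsample`, `final_hr_conv1`, consistent with the "groupnorm ... is REALLY bad" comment on the grad branch), `bias` flips to `True`. That pairing follows the usual convention that a conv bias is redundant when a GroupNorm with an affine shift comes right after it, so the bias only needs to return once the norm goes. A sketch of that convention under assumed ConvGnLelu-like semantics (the helper name and group count are illustrative, not the actual arch_util implementation):

```python
import torch.nn as nn

def conv_gn_lelu_sketch(in_nc, out_nc, kernel_size=3, norm=True, activation=True, bias=None):
    """Conv -> optional GroupNorm -> optional LeakyReLU.

    When bias is unspecified, enable it exactly when norm is off:
    GroupNorm's learned per-channel shift subsumes a conv bias.
    """
    if bias is None:
        bias = not norm
    layers = [nn.Conv2d(in_nc, out_nc, kernel_size,
                        padding=kernel_size // 2, bias=bias)]
    if norm:
        layers.append(nn.GroupNorm(8, out_nc))  # 8 groups is an assumption
    if activation:
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)
```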