From b905b108da9c7c72ec91007bba19211414169471 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 18 Dec 2020 09:10:44 -0700 Subject: [PATCH] Large cleanup Removed a lot of old code that I won't be touching again. Refactored some code elements into more logical places. --- .gitmodules | 3 + codes/models/ExtensibleTrainer.py | 4 +- codes/models/archs/SPSR_arch.py | 668 ------------------ codes/models/archs/SPSR_util.py | 163 ----- codes/models/archs/SRResNet_arch.py | 55 -- .../archs/biggan/biggan_discriminator.py | 139 ---- codes/models/archs/biggan/biggan_layers.py | 457 ------------ .../models/{steps => archs/byol}/__init__.py | 0 .../{ => archs}/byol/byol_model_wrapper.py | 0 .../{ => archs}/byol/byol_structural.py | 5 +- codes/models/archs/flownet2 | 1 + codes/models/archs/lambda_rrdb.py | 47 -- codes/models/archs/multi_res_rrdb.py | 206 ------ codes/models/archs/pyramid_arch.py | 98 --- codes/models/archs/pytorch_ssim.py | 80 --- codes/models/archs/rcan.py | 221 ------ codes/models/archs/tecogan/__init__.py | 0 .../models/archs/{ => tecogan}/teco_resgen.py | 0 codes/models/archs/transformers/igpt/gpt2.py | 3 +- .../custom_training_components/__init__.py | 0 .../progressive_zoom.py | 5 +- .../stereoscopic.py | 6 +- .../tecogan_losses.py | 8 +- codes/models/{steps => }/injectors.py | 74 +- codes/models/{steps => }/losses.py | 3 +- codes/models/networks.py | 52 +- codes/models/{steps => }/steps.py | 6 +- codes/scripts/use_generator_as_filter.py | 3 - codes/utils/onnx_inference.py | 22 - 29 files changed, 31 insertions(+), 2298 deletions(-) delete mode 100644 codes/models/archs/SPSR_arch.py delete mode 100644 codes/models/archs/SPSR_util.py delete mode 100644 codes/models/archs/SRResNet_arch.py delete mode 100644 codes/models/archs/biggan/biggan_discriminator.py delete mode 100644 codes/models/archs/biggan/biggan_layers.py rename codes/models/{steps => archs/byol}/__init__.py (100%) rename codes/models/{ => archs}/byol/byol_model_wrapper.py (100%) rename codes/models/{ => archs}/byol/byol_structural.py (97%) create mode 160000 codes/models/archs/flownet2 delete mode 100644 codes/models/archs/lambda_rrdb.py delete mode 100644 codes/models/archs/multi_res_rrdb.py delete mode 100644 codes/models/archs/pyramid_arch.py delete mode 100644 codes/models/archs/pytorch_ssim.py delete mode 100644 codes/models/archs/rcan.py create mode 100644 codes/models/archs/tecogan/__init__.py rename codes/models/archs/{ => tecogan}/teco_resgen.py (100%) create mode 100644 codes/models/custom_training_components/__init__.py rename codes/models/{steps => custom_training_components}/progressive_zoom.py (97%) rename codes/models/{steps => custom_training_components}/stereoscopic.py (90%) rename codes/models/{steps => custom_training_components}/tecogan_losses.py (98%) rename codes/models/{steps => }/injectors.py (83%) rename codes/models/{steps => }/losses.py (99%) rename codes/models/{steps => }/steps.py (98%) delete mode 100644 codes/utils/onnx_inference.py diff --git a/.gitmodules b/.gitmodules index 3f14c362..92c025e8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,3 +8,6 @@ path = codes/models/flownet2 url = https://github.com/neonbjb/flownet2-pytorch.git branch = master +[submodule "codes/models/archs/flownet2"] + path = codes/models/archs/flownet2 + url = https://github.com/neonbjb/flownet2-pytorch.git diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 9e38e02e..9b62576d 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -8,8 +8,8 @@ import torch.nn as nn import models.lr_scheduler as lr_scheduler import models.networks as networks from models.base_model import BaseModel -from models.steps.injectors import create_injector -from models.steps.steps import ConfigurableStep +from models.injectors import create_injector +from models.steps import ConfigurableStep from models.experiments.experiments import get_experiment_for_name import torchvision.utils as utils diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py deleted file mode 100644 index 9d2ec8ea..00000000 --- a/codes/models/archs/SPSR_arch.py +++ /dev/null @@ -1,668 +0,0 @@ -import functools -import os - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision -from utils.util import checkpoint - -from models.archs import SPSR_util as B -from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \ - QueryKeyMultiplexer, QueryKeyPyramidMultiplexer, ConvBasisMultiplexer -from models.archs.arch_util import ConvGnLelu, UpconvBlock, MultiConvBlock, ReferenceJoinBlock -from switched_conv.switched_conv import compute_attention_specificity -from switched_conv.switched_conv_util import save_attention_to_image_rgb -from .RRDBNet_arch import RRDB - - -class ImageGradient(nn.Module): - def __init__(self): - super(ImageGradient, self).__init__() - kernel_v = [[0, -1, 0], - [0, 0, 0], - [0, 1, 0]] - kernel_h = [[0, 0, 0], - [-1, 0, 1], - [0, 0, 0]] - kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) - kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) - self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda() - self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda() - - def forward(self, x): - x0 = x[:, 0] - x1 = x[:, 1] - x2 = x[:, 2] - x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2) - x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2) - - x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2) - x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2) - - x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2) - x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2) - - x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) - x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) - x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) - - x = torch.cat([x0, x1, x2], dim=1) - return x - - -class ImageGradientNoPadding(nn.Module): - def __init__(self): - super(ImageGradientNoPadding, self).__init__() - kernel_v = [[0, -1, 0], - [0, 0, 0], - [0, 1, 0]] - kernel_h = [[0, 0, 0], - [-1, 0, 1], - [0, 0, 0]] - kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) - kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) - self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False) - - self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False) - - - def forward(self, x): - x = x.float() - x_list = [] - for i in range(x.shape[1]): - x_i = x[:, i] - x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) - x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) - x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) - x_list.append(x_i) - - x = torch.cat(x_list, dim = 1) - - return x - - -#################### -# Generator -#################### - -class SPSRNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ - act_type='leakyrelu', mode='CNA', upsample_mode='upconv'): - super(SPSRNet, self).__init__() - - n_upscale = int(math.log(upscale, 2)) - - self.scale=upscale - if upscale == 3: - n_upscale = 1 - - fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False) - self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=7, norm=False, activation=False) - self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - rb_blocks = [RRDB(nf) for _ in range(nb)] - - LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - - upsample_block = UpconvBlock - if upscale == 3: - upsampler = upsample_block(nf, nf, activation=True) - else: - upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)] - - self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True) - self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - - self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \ - *upsampler, self.HR_conv0_new) - - self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False) - self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=3, norm=False, activation=False) - self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - - self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) - self.b_block_1 = RRDB(nf * 2) - - self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) - self.b_block_2 = RRDB(nf * 2) - - self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) - self.b_block_3 = RRDB(nf * 2) - - self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) - self.b_block_4 = RRDB(nf * 2) - - self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - - if upscale == 3: - b_upsampler = UpconvBlock(nf, nf, activation=True) - else: - b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)] - - b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True) - b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - - self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1) - - self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False) - - self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False) - - self.f_block = RRDB(nf * 2) - - self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True) - self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False) - - self.get_g_nopadding = ImageGradientNoPadding() - - def bl1(self, x): - block_list = self.model[1].sub - for i in range(5): - x = block_list[i](x) - return x - - def bl2(self, x): - block_list = self.model[1].sub - for i in range(5): - x = block_list[i+5](x) - return x - - def bl3(self, x): - block_list = self.model[1].sub - for i in range(5): - x = block_list[i+10](x) - return x - - def bl4(self, x): - block_list = self.model[1].sub - for i in range(5): - x = block_list[i+15](x) - return x - - def bl5(self, x): - block_list = self.model[1].sub - x = block_list[20:](x) - return x - - def bl6(self, x_ori, x): - x = x_ori + x - x = self.model[2:](x) - x = self.HR_conv1_new(x) - return x - - def branch_bl1(self, x_grad, ref_grad): - x_b_fea = self.b_fea_conv(x_grad) - x_b_ref = self.b_ref_conv(ref_grad) - x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1)) - return x_b_fea - - def branch_bl2(self, x_b_fea, x_fea1): - x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) - x_cat_1 = self.b_block_1(x_cat_1) - x_cat_1 = self.b_concat_1(x_cat_1) - return x_cat_1 - - def branch_bl3(self, x_cat_1, x_fea2): - x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) - x_cat_2 = self.b_block_2(x_cat_2) - x_cat_2 = self.b_concat_2(x_cat_2) - return x_cat_2 - - def branch_bl4(self, x_cat_2, x_fea3): - x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) - x_cat_3 = self.b_block_3(x_cat_3) - x_cat_3 = self.b_concat_3(x_cat_3) - return x_cat_3 - - def branch_bl5(self, x_cat_3, x_fea4, x_b_fea): - x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) - x_cat_4 = self.b_block_4(x_cat_4) - x_cat_4 = self.b_concat_4(x_cat_4) - x_cat_4 = self.b_LR_conv(x_cat_4) - x_cat_4 = x_cat_4 + x_b_fea - x_branch = self.b_module(x_cat_4) - return x_branch - - def forward(self, x, ref=None): - b,f,h,w = x.shape - if ref is None: - ref = torch.zeros((b,f,h*self.scale,w*self.scale), device=x.device, dtype=x.dtype) - - x_grad = self.get_g_nopadding(x) - ref_grad = self.get_g_nopadding(ref) - x = self.model[0](x) - x_ref = self.ref_conv(ref) - x = self.join_conv(torch.cat([x, x_ref], dim=1)) - - x, block_list = self.model[1](x) - x_ori = x - x = checkpoint(self.bl1, x) - x_fea1 = x - x = checkpoint(self.bl2, x) - x_fea2 = x - x = checkpoint(self.bl3, x) - x_fea3 = x - x = checkpoint(self.bl4, x) - x_fea4 = x - x = checkpoint(self.bl5, x) - x = checkpoint(self.bl6, x_ori, x) - - x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad) - x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1) - x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2) - x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3) - x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea) - - x_out_branch = checkpoint(self.conv_w, x_branch) - ######## - x_branch_d = x_branch - x_f_cat = torch.cat([x_branch_d, x], dim=1) - x_f_cat = checkpoint(self.f_block, x_f_cat) - x_out = self.f_concat(x_f_cat) - x_out = checkpoint(self.f_HR_conv0, x_out) - x_out = checkpoint(self.f_HR_conv1, x_out) - - ######### - return x_out_branch, x_out, x_grad - - -class SPSRNetSimplified(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, upscale=4): - super(SPSRNetSimplified, self).__init__() - n_upscale = int(math.log(upscale, 2)) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) - self.model_shortcut_blk = nn.Sequential(*[RRDB(nf, gc=32) for _ in range(nb)]) - self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) - self.feature_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) - self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) - - # Grad branch - self.get_g_nopadding = ImageGradientNoPadding() - self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) - self.b_concat_decimate_1 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) - self.b_proc_block_1 = RRDB(nf, gc=32) - self.b_concat_decimate_2 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) - self.b_proc_block_2 = RRDB(nf, gc=32) - self.b_concat_decimate_3 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) - self.b_proc_block_3 = RRDB(nf, gc=32) - self.b_concat_decimate_4 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) - self.b_proc_block_4 = RRDB(nf, gc=32) - - # Upsampling - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) - b_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) - grad_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) - grad_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) - self.branch_upsample = B.sequential(*b_upsampler, grad_hr_conv1, grad_hr_conv2) - # Conv used to output grad branch shortcut. - self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) - - # Conjoin branch. - # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. - self._branch_pretrain_concat = ConvGnLelu(nf * 2, nf, kernel_size=1, norm=False, activation=False, bias=False) - self._branch_pretrain_block = RRDB(nf * 2, gc=32) - self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) - self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - - def forward(self, x): - - x_grad = self.get_g_nopadding(x) - x = self.model_fea_conv(x) - - x_ori = x - for i in range(5): - x = self.model_shortcut_blk[i](x) - x_fea1 = x - - for i in range(5): - x = self.model_shortcut_blk[i + 5](x) - x_fea2 = x - - for i in range(5): - x = self.model_shortcut_blk[i + 10](x) - x_fea3 = x - - for i in range(5): - x = self.model_shortcut_blk[i + 15](x) - x_fea4 = x - - x = self.model_shortcut_blk[20:](x) - x = self.feature_lr_conv(x) - - # short cut - x = x_ori + x - x = self.model_upsampler(x) - x = self.feature_hr_conv1(x) - x = self.feature_hr_conv2(x) - - x_b_fea = self.b_fea_conv(x_grad) - x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) - - x_cat_1 = self.b_concat_decimate_1(x_cat_1) - x_cat_1 = self.b_proc_block_1(x_cat_1) - - x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) - - x_cat_2 = self.b_concat_decimate_2(x_cat_2) - x_cat_2 = self.b_proc_block_2(x_cat_2) - - x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) - - x_cat_3 = self.b_concat_decimate_3(x_cat_3) - x_cat_3 = self.b_proc_block_3(x_cat_3) - - x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) - - x_cat_4 = self.b_concat_decimate_4(x_cat_4) - x_cat_4 = self.b_proc_block_4(x_cat_4) - - x_cat_4 = self.grad_lr_conv(x_cat_4) - - # short cut - x_cat_4 = x_cat_4 + x_b_fea - x_branch = self.branch_upsample(x_cat_4) - x_out_branch = self.grad_branch_output_conv(x_branch) - - ######## - x_branch_d = x_branch - x__branch_pretrain_cat = torch.cat([x_branch_d, x], dim=1) - x__branch_pretrain_cat = self._branch_pretrain_block(x__branch_pretrain_cat) - x_out = self._branch_pretrain_concat(x__branch_pretrain_cat) - x_out = self._branch_pretrain_HR_conv0(x_out) - x_out = self._branch_pretrain_HR_conv1(x_out) - - ######### - return x_out_branch, x_out, x_grad - - -# Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding. -class Spsr7(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10): - super(Spsr7, self).__init__() - n_upscale = int(math.log(upscale, 2)) - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - self.recurrent = recurrent - if recurrent: - self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False, - bias=True) - self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01) - - # switch options - self.nf = nf - transformation_filters = nf - self.transformation_counts = xforms - multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters, embedding_channels=512, reductions=multiplexer_reductions) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), - transformation_filters, kernel_size=3, depth=3, - weight_init_factor=.1) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) - self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False) - - self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts // 2, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True) - self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) - self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) - self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3) - self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) - self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) - self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw] - self.attentions = None - self.init_temperature = init_temperature - self.final_temperature_step = 10000 - self.lr = None - - def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - x_grad = self.get_g_nopadding(x) - ref_code = self.reference_embedding(ref, ref_center) - ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - if self.recurrent: - rec = self.model_recurrent_conv(recurrent) - br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1)) - x = x + br - - x1 = x - x1, a1 = self.sw1(x1, identity=x, att_in=(x1, ref_embedding), do_checkpointing=True) - - x2 = x1 - x2, a2 = self.sw2(x2, identity=x1, att_in=(x2, ref_embedding), do_checkpointing=True) - - x_grad = self.grad_conv(x_grad) - x_grad_identity = x_grad - x_grad, grad_fea_std = checkpoint(self.grad_ref_join, x_grad, x1) - x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity, att_in=(x_grad, ref_embedding), do_checkpointing=True) - x_grad = checkpoint(self.grad_lr_conv, x_grad) - x_grad = checkpoint(self.grad_lr_conv2, x_grad) - x_grad_out = checkpoint(self.upsample_grad, x_grad) - x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) - - x_out = x2 - x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) - x_out, a4 = self.conjoin_sw(x_out, identity=x2, att_in=(x_out, ref_embedding), do_checkpointing=True) - x_out = checkpoint(self.final_lr_conv, x_out) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv1, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - self.attentions = [a1, a2, a3, a4] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 500 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] - torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp, - "grad_branch_feat_intg_std_dev": self.grad_fea_std, - "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - - -class AttentionBlock(nn.Module): - def __init__(self, nf, num_transforms, multiplexer_reductions, init_temperature=10, has_ref=True): - super(AttentionBlock, self).__init__() - self.nf = nf - self.transformation_counts = num_transforms - multiplx_fn = functools.partial(QueryKeyMultiplexer, nf, embedding_channels=512, reductions=multiplexer_reductions) - transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), - nf, kernel_size=3, depth=4, - weight_init_factor=.1) - if has_ref: - self.ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False) - else: - self.ref_join = None - self.switch = ConfigurableSwitchComputer(nf, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - - def forward(self, x, mplex_ref=None, ref=None): - if self.ref_join is not None: - branch, ref_std = self.ref_join(x, ref) - return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,) - else: - return self.switch(x, identity=x, att_in=(x, mplex_ref)) - - -class SwitchedSpsr(nn.Module): - def __init__(self, in_nc, nf, xforms=8, upscale=4, init_temperature=10): - super(SwitchedSpsr, self).__init__() - n_upscale = int(math.log(upscale, 2)) - - # switch options - transformation_filters = nf - switch_filters = nf - switch_reductions = 3 - switch_processing_layers = 2 - self.transformation_counts = xforms - multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, - switch_processing_layers, self.transformation_counts, use_exp2=True) - pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), - transformation_filters, kernel_size=3, depth=3, - weight_init_factor=.1) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) - self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=True) - self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=True) - self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) - self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) - - # Grad branch - self.get_g_nopadding = ImageGradientNoPadding() - self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) - mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions, - switch_processing_layers, self.transformation_counts // 2, use_exp2=True) - self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts // 2, init_temp=init_temperature, - add_scalable_noise_to_transforms=True) - # Upsampling - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) - self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) - # Conv used to output grad branch shortcut. - self.grad_branch_output_conv = ConvGnLelu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False) - - # Conjoin branch. - # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. - transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5), - transformation_filters, kernel_size=3, depth=4, - weight_init_factor=.1) - pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1) - self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)]) - self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)]) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) - self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf, 3, kernel_size=3, norm=False, activation=False, bias=False) - self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw] - self.attentions = None - self.init_temperature = init_temperature - self.final_temperature_step = 10000 - - def forward(self, x): - x_grad = self.get_g_nopadding(x) - x = self.model_fea_conv(x) - - x1, a1 = self.sw1(x, do_checkpointing=True) - x2, a2 = self.sw2(x1, do_checkpointing=True) - x_fea = self.feature_lr_conv(x2) - x_fea = self.feature_hr_conv2(x_fea) - - x_b_fea = self.b_fea_conv(x_grad) - x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True, do_checkpointing=True) - x_grad = checkpoint(self.grad_lr_conv, x_grad) - x_grad = checkpoint(self.grad_hr_conv, x_grad) - x_out_branch = checkpoint(self.upsample_grad, x_grad) - x_out_branch = self.grad_branch_output_conv(x_out_branch) - - x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1) - x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True) - x_out = checkpoint(self.final_lr_conv, x__branch_pretrain_cat) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv1, x_out) - x_out = self.final_hr_conv2(x_out) - - self.attentions = [a1, a2, a3, a4] - - return x_out_branch, x_out, x_grad - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 200 == 0: - output_path = os.path.join(experiments_path, "attention_maps", "a%i") - prefix = "attention_map_%i_%%i.png" % (step,) - [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] - - def get_debug_values(self, step, net): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val diff --git a/codes/models/archs/SPSR_util.py b/codes/models/archs/SPSR_util.py deleted file mode 100644 index df40325f..00000000 --- a/codes/models/archs/SPSR_util.py +++ /dev/null @@ -1,163 +0,0 @@ -from collections import OrderedDict -import torch -import torch.nn as nn - -#################### -# Basic blocks -#################### -def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1): - # helper selecting activation - # neg_slope: for leakyrelu and init of prelu - # n_prelu: for p_relu num_parameters - act_type = act_type.lower() - if act_type == 'relu': - layer = nn.ReLU(inplace) - elif act_type == 'leakyrelu': - layer = nn.LeakyReLU(neg_slope, inplace) - elif act_type == 'prelu': - layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) - else: - raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) - return layer - -def norm(norm_type, nc): - # helper selecting normalization layer - norm_type = norm_type.lower() - if norm_type == 'batch': - layer = nn.BatchNorm2d(nc, affine=True) - elif norm_type == 'instance': - layer = nn.InstanceNorm2d(nc, affine=False) - elif norm_type == 'group': - layer = nn.GroupNorm(8, nc) - else: - raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) - return layer - -def pad(pad_type, padding): - # helper selecting padding layer - # if padding is 'zero', do by conv layers - pad_type = pad_type.lower() - if padding == 0: - return None - if pad_type == 'reflect': - layer = nn.ReflectionPad2d(padding) - elif pad_type == 'replicate': - layer = nn.ReplicationPad2d(padding) - else: - raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) - return layer - - -def get_valid_padding(kernel_size, dilation): - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - padding = (kernel_size - 1) // 2 - return padding - - -class ConcatBlock(nn.Module): - # Concat the output of a submodule to its input - def __init__(self, submodule): - super(ConcatBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - output = torch.cat((x, self.sub(x)), dim=1) - return output - - def __repr__(self): - tmpstr = 'Identity .. \n|' - modstr = self.sub.__repr__().replace('\n', '\n|') - tmpstr = tmpstr + modstr - return tmpstr - - -class ShortcutBlock(nn.Module): - #Elementwise sum the output of a submodule to its input - def __init__(self, submodule): - super(ShortcutBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - return x, self.sub - - def __repr__(self): - tmpstr = 'Identity + \n|' - modstr = self.sub.__repr__().replace('\n', '\n|') - tmpstr = tmpstr + modstr - return tmpstr - - -def sequential(*args): - # Flatten Sequential. It unwraps nn.Sequential. - if len(args) == 1: - if isinstance(args[0], OrderedDict): - raise NotImplementedError('sequential does not support OrderedDict input.') - return args[0] # No sequential is needed. - modules = [] - for module in args: - if isinstance(module, nn.Sequential): - for submodule in module.children(): - modules.append(submodule) - elif isinstance(module, nn.Module): - modules.append(module) - return nn.Sequential(*modules) - - -def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \ - pad_type='zero', norm_type=None, act_type='relu', mode='CNA'): - ''' - Conv layer with padding, normalization, activation - mode: CNA --> Conv -> Norm -> Act - NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) - ''' - assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) - padding = get_valid_padding(kernel_size, dilation) - p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None - padding = padding if pad_type == 'zero' else 0 - - c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \ - dilation=dilation, bias=bias, groups=groups) - a = act(act_type) if act_type else None - if 'CNA' in mode: - n = norm(norm_type, out_nc) if norm_type else None - return sequential(p, c, n, a) - elif mode == 'NAC': - if norm_type is None and act_type is not None: - a = act(act_type, inplace=False) - # Important! - # input----ReLU(inplace)----Conv--+----output - # |________________________| - # inplace ReLU will modify the input, therefore wrong output - n = norm(norm_type, in_nc) if norm_type else None - return sequential(n, a, p, c) - - -#################### -# Upsampler -#################### - - -def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \ - pad_type='zero', norm_type=None, act_type='relu'): - ''' - Pixel shuffle layer - (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional - Neural Network, CVPR17) - ''' - conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \ - pad_type=pad_type, norm_type=None, act_type=None) - pixel_shuffle = nn.PixelShuffle(upscale_factor) - - n = norm(norm_type, out_nc) if norm_type else None - a = act(act_type) if act_type else None - return sequential(conv, pixel_shuffle, n, a) - - -def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \ - pad_type='zero', norm_type=None, act_type='relu', mode='nearest'): - # Up conv - # described in https://distill.pub/2016/deconv-checkerboard/ - upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) - conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \ - pad_type=pad_type, norm_type=norm_type, act_type=act_type) - return sequential(upsample, conv) diff --git a/codes/models/archs/SRResNet_arch.py b/codes/models/archs/SRResNet_arch.py deleted file mode 100644 index 6e622ac3..00000000 --- a/codes/models/archs/SRResNet_arch.py +++ /dev/null @@ -1,55 +0,0 @@ -import functools -import torch.nn as nn -import torch.nn.functional as F -import models.archs.arch_util as arch_util - - -class MSRResNet(nn.Module): - ''' modified SRResNet''' - - def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): - super(MSRResNet, self).__init__() - self.upscale = upscale - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) - self.recon_trunk = arch_util.make_layer(basic_block, nb) - - # upsampling - if self.upscale == 2: - self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) - self.pixel_shuffle = nn.PixelShuffle(2) - elif self.upscale == 3: - self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) - self.pixel_shuffle = nn.PixelShuffle(3) - elif self.upscale == 4: - self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) - self.pixel_shuffle = nn.PixelShuffle(2) - - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - # activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - # initialization - arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], - 0.1) - if self.upscale == 4: - arch_util.initialize_weights(self.upconv2, 0.1) - - def forward(self, x): - fea = self.lrelu(self.conv_first(x)) - out = self.recon_trunk(fea) - - if self.upscale == 4: - out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) - out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) - elif self.upscale == 3 or self.upscale == 2: - out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) - - out = self.conv_last(self.lrelu(self.HRconv(out))) - base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) - out += base - return out diff --git a/codes/models/archs/biggan/biggan_discriminator.py b/codes/models/archs/biggan/biggan_discriminator.py deleted file mode 100644 index a85f4443..00000000 --- a/codes/models/archs/biggan/biggan_discriminator.py +++ /dev/null @@ -1,139 +0,0 @@ -import functools - -import torch -from torch.nn import init - -import models.archs.biggan.biggan_layers as layers -import torch.nn as nn - - -# Discriminator architecture, same paradigm as G's above -def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): - arch = {} - arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]], - 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], - 'downsample' : [True] * 6 + [False], - 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], - 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] - for i in range(2,8)}} - arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]], - 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]], - 'downsample' : [True] * 5 + [False], - 'resolution' : [64, 32, 16, 8, 4, 4], - 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] - for i in range(2,8)}} - arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]], - 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]], - 'downsample' : [True] * 4 + [False], - 'resolution' : [32, 16, 8, 4, 4], - 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] - for i in range(2,7)}} - arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]], - 'out_channels' : [item * ch for item in [4, 4, 4, 4]], - 'downsample' : [True, True, False, False], - 'resolution' : [16, 16, 16, 16], - 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] - for i in range(2,6)}} - return arch - - -class BigGanDiscriminator(nn.Module): - - def __init__(self, D_ch=64, D_wide=True, resolution=128, - D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), - SN_eps=1e-12, output_dim=1, D_fp16=False, - D_init='ortho', skip_init=False, D_param='SN'): - super(BigGanDiscriminator, self).__init__() - # Width multiplier - self.ch = D_ch - # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? - self.D_wide = D_wide - # Resolution - self.resolution = resolution - # Kernel size - self.kernel_size = D_kernel_size - # Attention? - self.attention = D_attn - # Activation - self.activation = D_activation - # Initialization style - self.init = D_init - # Parameterization style - self.D_param = D_param - # Epsilon for Spectral Norm? - self.SN_eps = SN_eps - # Fp16? - self.fp16 = D_fp16 - # Architecture - self.arch = D_arch(self.ch, self.attention)[resolution] - - # Which convs, batchnorms, and linear layers to use - # No option to turn off SN in D right now - if self.D_param == 'SN': - self.which_conv = functools.partial(layers.SNConv2d, - kernel_size=3, padding=1, - num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, - eps=self.SN_eps) - self.which_linear = functools.partial(layers.SNLinear, - num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, - eps=self.SN_eps) - self.which_embedding = functools.partial(layers.SNEmbedding, - num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, - eps=self.SN_eps) - # Prepare model - # self.blocks is a doubly-nested list of modules, the outer loop intended - # to be over blocks at a given resolution (resblocks and/or self-attention) - self.blocks = [] - for index in range(len(self.arch['out_channels'])): - self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], - out_channels=self.arch['out_channels'][index], - which_conv=self.which_conv, - wide=self.D_wide, - activation=self.activation, - preactivation=(index > 0), - downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] - # If attention on this block, attach it to the end - if self.arch['attention'][self.arch['resolution'][index]]: - print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) - self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], - self.which_conv)] - # Turn self.blocks into a ModuleList so that it's all properly registered. - self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) - # Linear output layer. The output dimension is typically 1, but may be - # larger if we're e.g. turning this into a VAE with an inference output - self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) - - # Initialize weights - if not skip_init: - self.init_weights() - - # Initialize - def init_weights(self): - self.param_count = 0 - for module in self.modules(): - if (isinstance(module, nn.Conv2d) - or isinstance(module, nn.Linear) - or isinstance(module, nn.Embedding)): - if self.init == 'ortho': - init.orthogonal_(module.weight) - elif self.init == 'N02': - init.normal_(module.weight, 0, 0.02) - elif self.init in ['glorot', 'xavier']: - init.xavier_uniform_(module.weight) - else: - print('Init style not recognized...') - self.param_count += sum([p.data.nelement() for p in module.parameters()]) - print('Param count for D''s initialized parameters: %d' % self.param_count) - - def forward(self, x, y=None): - # Stick x into h for cleaner for loops without flow control - h = x - # Loop over blocks - for index, blocklist in enumerate(self.blocks): - for block in blocklist: - h = block(h) - # Apply global sum pooling as in SN-GAN - h = torch.sum(self.activation(h), [2, 3]) - # Get initial class-unconditional output - out = self.linear(h) - return out diff --git a/codes/models/archs/biggan/biggan_layers.py b/codes/models/archs/biggan/biggan_layers.py deleted file mode 100644 index 292d167f..00000000 --- a/codes/models/archs/biggan/biggan_layers.py +++ /dev/null @@ -1,457 +0,0 @@ -''' Layers - This file contains various layers for the BigGAN models. -''' -import numpy as np -import torch -import torch.nn as nn -from torch.nn import init -import torch.optim as optim -import torch.nn.functional as F -from torch.nn import Parameter as P - - -# Projection of x onto y -def proj(x, y): - return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) - - -# Orthogonalize x wrt list of vectors ys -def gram_schmidt(x, ys): - for y in ys: - x = x - proj(x, y) - return x - - -# Apply num_itrs steps of the power method to estimate top N singular values. -def power_iteration(W, u_, update=True, eps=1e-12): - # Lists holding singular vectors and values - us, vs, svs = [], [], [] - for i, u in enumerate(u_): - # Run one step of the power iteration - with torch.no_grad(): - v = torch.matmul(u, W) - # Run Gram-Schmidt to subtract components of all other singular vectors - v = F.normalize(gram_schmidt(v, vs), eps=eps) - # Add to the list - vs += [v] - # Update the other singular vector - u = torch.matmul(v, W.t()) - # Run Gram-Schmidt to subtract components of all other singular vectors - u = F.normalize(gram_schmidt(u, us), eps=eps) - # Add to the list - us += [u] - if update: - u_[i][:] = u - # Compute this singular value and add it to the list - svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] - # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] - return svs, us, vs - - -# Convenience passthrough function -class identity(nn.Module): - def forward(self, input): - return input - - -# Spectral normalization base class -class SN(object): - def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): - # Number of power iterations per step - self.num_itrs = num_itrs - # Number of singular values - self.num_svs = num_svs - # Transposed? - self.transpose = transpose - # Epsilon value for avoiding divide-by-0 - self.eps = eps - # Register a singular vector for each sv - for i in range(self.num_svs): - self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) - self.register_buffer('sv%d' % i, torch.ones(1)) - - # Singular vectors (u side) - @property - def u(self): - return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] - - # Singular values; - # note that these buffers are just for logging and are not used in training. - @property - def sv(self): - return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] - - # Compute the spectrally-normalized weight - def W_(self): - W_mat = self.weight.view(self.weight.size(0), -1) - if self.transpose: - W_mat = W_mat.t() - # Apply num_itrs power iterations - for _ in range(self.num_itrs): - svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) - # Update the svs - if self.training: - with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! - for i, sv in enumerate(svs): - self.sv[i][:] = sv - return self.weight / svs[0] - - -# 2D Conv layer with spectral norm -class SNConv2d(nn.Conv2d, SN): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True, - num_svs=1, num_itrs=1, eps=1e-12): - nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, groups, bias) - SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) - - def forward(self, x): - return F.conv2d(x, self.W_(), self.bias, self.stride, - self.padding, self.dilation, self.groups) - - -# Linear layer with spectral norm -class SNLinear(nn.Linear, SN): - def __init__(self, in_features, out_features, bias=True, - num_svs=1, num_itrs=1, eps=1e-12): - nn.Linear.__init__(self, in_features, out_features, bias) - SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) - - def forward(self, x): - return F.linear(x, self.W_(), self.bias) - - -# Embedding layer with spectral norm -# We use num_embeddings as the dim instead of embedding_dim here -# for convenience sake -class SNEmbedding(nn.Embedding, SN): - def __init__(self, num_embeddings, embedding_dim, padding_idx=None, - max_norm=None, norm_type=2, scale_grad_by_freq=False, - sparse=False, _weight=None, - num_svs=1, num_itrs=1, eps=1e-12): - nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, - max_norm, norm_type, scale_grad_by_freq, - sparse, _weight) - SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) - - def forward(self, x): - return F.embedding(x, self.W_()) - - -# A non-local block as used in SA-GAN -# Note that the implementation as described in the paper is largely incorrect; -# refer to the released code for the actual implementation. -class Attention(nn.Module): - def __init__(self, ch, which_conv=SNConv2d, name='attention'): - super(Attention, self).__init__() - # Channel multiplier - self.ch = ch - self.which_conv = which_conv - self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) - self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) - self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) - self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) - # Learnable gain parameter - self.gamma = P(torch.tensor(0.), requires_grad=True) - - def forward(self, x, y=None): - # Apply convs - theta = self.theta(x) - phi = F.max_pool2d(self.phi(x), [2, 2]) - g = F.max_pool2d(self.g(x), [2, 2]) - # Perform reshapes - theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3]) - phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4) - g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4) - # Matmul and softmax to get attention maps - beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) - # Attention map times g path - o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) - return self.gamma * o + x - - -# Fused batchnorm op -def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): - # Apply scale and shift--if gain and bias are provided, fuse them here - # Prepare scale - scale = torch.rsqrt(var + eps) - # If a gain is provided, use it - if gain is not None: - scale = scale * gain - # Prepare shift - shift = mean * scale - # If bias is provided, use it - if bias is not None: - shift = shift - bias - return x * scale - shift - # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. - - -# Manual BN -# Calculate means and variances using mean-of-squares minus mean-squared -def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): - # Cast x to float32 if necessary - float_x = x.float() - # Calculate expected value of x (m) and expected value of x**2 (m2) - # Mean of x - m = torch.mean(float_x, [0, 2, 3], keepdim=True) - # Mean of x squared - m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) - # Calculate variance as mean of squared minus mean squared. - var = (m2 - m ** 2) - # Cast back to float 16 if necessary - var = var.type(x.type()) - m = m.type(x.type()) - # Return mean and variance for updating stored mean/var if requested - if return_mean_var: - return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() - else: - return fused_bn(x, m, var, gain, bias, eps) - - -# My batchnorm, supports standing stats -class myBN(nn.Module): - def __init__(self, num_channels, eps=1e-5, momentum=0.1): - super(myBN, self).__init__() - # momentum for updating running stats - self.momentum = momentum - # epsilon to avoid dividing by 0 - self.eps = eps - # Momentum - self.momentum = momentum - # Register buffers - self.register_buffer('stored_mean', torch.zeros(num_channels)) - self.register_buffer('stored_var', torch.ones(num_channels)) - self.register_buffer('accumulation_counter', torch.zeros(1)) - # Accumulate running means and vars - self.accumulate_standing = False - - # reset standing stats - def reset_stats(self): - self.stored_mean[:] = 0 - self.stored_var[:] = 0 - self.accumulation_counter[:] = 0 - - def forward(self, x, gain, bias): - if self.training: - out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) - # If accumulating standing stats, increment them - if self.accumulate_standing: - self.stored_mean[:] = self.stored_mean + mean.data - self.stored_var[:] = self.stored_var + var.data - self.accumulation_counter += 1.0 - # If not accumulating standing stats, take running averages - else: - self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum - self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum - return out - # If not in training mode, use the stored statistics - else: - mean = self.stored_mean.view(1, -1, 1, 1) - var = self.stored_var.view(1, -1, 1, 1) - # If using standing stats, divide them by the accumulation counter - if self.accumulate_standing: - mean = mean / self.accumulation_counter - var = var / self.accumulation_counter - return fused_bn(x, mean, var, gain, bias, self.eps) - - -# Simple function to handle groupnorm norm stylization -def groupnorm(x, norm_style): - # If number of channels specified in norm_style: - if 'ch' in norm_style: - ch = int(norm_style.split('_')[-1]) - groups = max(int(x.shape[1]) // ch, 1) - # If number of groups specified in norm style - elif 'grp' in norm_style: - groups = int(norm_style.split('_')[-1]) - # If neither, default to groups = 16 - else: - groups = 16 - return F.group_norm(x, groups) - - -# Class-conditional bn -# output size is the number of channels, input size is for the linear layers -# Andy's Note: this class feels messy but I'm not really sure how to clean it up -# Suggestions welcome! (By which I mean, refactor this and make a pull request -# if you want to make this more readable/usable). -class ccbn(nn.Module): - def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, - cross_replica=False, mybn=False, norm_style='bn', ): - super(ccbn, self).__init__() - self.output_size, self.input_size = output_size, input_size - # Prepare gain and bias layers - self.gain = which_linear(input_size, output_size) - self.bias = which_linear(input_size, output_size) - # epsilon to avoid dividing by 0 - self.eps = eps - # Momentum - self.momentum = momentum - # Use cross-replica batchnorm? - self.cross_replica = cross_replica - # Use my batchnorm? - self.mybn = mybn - # Norm style? - self.norm_style = norm_style - - if self.cross_replica or self.mybn: - self.bn = myBN(output_size, self.eps, self.momentum) - elif self.norm_style in ['bn', 'in']: - self.register_buffer('stored_mean', torch.zeros(output_size)) - self.register_buffer('stored_var', torch.ones(output_size)) - - def forward(self, x, y): - # Calculate class-conditional gains and biases - gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) - bias = self.bias(y).view(y.size(0), -1, 1, 1) - # If using my batchnorm - if self.mybn or self.cross_replica: - return self.bn(x, gain=gain, bias=bias) - # else: - else: - if self.norm_style == 'bn': - out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, - self.training, 0.1, self.eps) - elif self.norm_style == 'in': - out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, - self.training, 0.1, self.eps) - elif self.norm_style == 'gn': - out = groupnorm(x, self.normstyle) - elif self.norm_style == 'nonorm': - out = x - return out * gain + bias - - def extra_repr(self): - s = 'out: {output_size}, in: {input_size},' - s += ' cross_replica={cross_replica}' - return s.format(**self.__dict__) - - -# Normal, non-class-conditional BN -class bn(nn.Module): - def __init__(self, output_size, eps=1e-5, momentum=0.1, - cross_replica=False, mybn=False): - super(bn, self).__init__() - self.output_size = output_size - # Prepare gain and bias layers - self.gain = P(torch.ones(output_size), requires_grad=True) - self.bias = P(torch.zeros(output_size), requires_grad=True) - # epsilon to avoid dividing by 0 - self.eps = eps - # Momentum - self.momentum = momentum - # Use cross-replica batchnorm? - self.cross_replica = cross_replica - # Use my batchnorm? - self.mybn = mybn - - if self.cross_replica or mybn: - self.bn = myBN(output_size, self.eps, self.momentum) - # Register buffers if neither of the above - else: - self.register_buffer('stored_mean', torch.zeros(output_size)) - self.register_buffer('stored_var', torch.ones(output_size)) - - def forward(self, x, y=None): - if self.cross_replica or self.mybn: - gain = self.gain.view(1, -1, 1, 1) - bias = self.bias.view(1, -1, 1, 1) - return self.bn(x, gain=gain, bias=bias) - else: - return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, - self.bias, self.training, self.momentum, self.eps) - - -# Generator blocks -# Note that this class assumes the kernel size and padding (and any other -# settings) have been selected in the main generator module and passed in -# through the which_conv arg. Similar rules apply with which_bn (the input -# size [which is actually the number of channels of the conditional info] must -# be preselected) -class GBlock(nn.Module): - def __init__(self, in_channels, out_channels, - which_conv=nn.Conv2d, which_bn=bn, activation=None, - upsample=None): - super(GBlock, self).__init__() - - self.in_channels, self.out_channels = in_channels, out_channels - self.which_conv, self.which_bn = which_conv, which_bn - self.activation = activation - self.upsample = upsample - # Conv layers - self.conv1 = self.which_conv(self.in_channels, self.out_channels) - self.conv2 = self.which_conv(self.out_channels, self.out_channels) - self.learnable_sc = in_channels != out_channels or upsample - if self.learnable_sc: - self.conv_sc = self.which_conv(in_channels, out_channels, - kernel_size=1, padding=0) - # Batchnorm layers - self.bn1 = self.which_bn(in_channels) - self.bn2 = self.which_bn(out_channels) - # upsample layers - self.upsample = upsample - - def forward(self, x, y): - h = self.activation(self.bn1(x, y)) - if self.upsample: - h = self.upsample(h) - x = self.upsample(x) - h = self.conv1(h) - h = self.activation(self.bn2(h, y)) - h = self.conv2(h) - if self.learnable_sc: - x = self.conv_sc(x) - return h + x - - -# Residual block for the discriminator -class DBlock(nn.Module): - def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, - preactivation=False, activation=None, downsample=None, ): - super(DBlock, self).__init__() - self.in_channels, self.out_channels = in_channels, out_channels - # If using wide D (as in SA-GAN and BigGAN), change the channel pattern - self.hidden_channels = self.out_channels if wide else self.in_channels - self.which_conv = which_conv - self.preactivation = preactivation - self.activation = activation - self.downsample = downsample - - # Conv layers - self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) - self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) - self.learnable_sc = True if (in_channels != out_channels) or downsample else False - if self.learnable_sc: - self.conv_sc = self.which_conv(in_channels, out_channels, - kernel_size=1, padding=0) - - def shortcut(self, x): - if self.preactivation: - if self.learnable_sc: - x = self.conv_sc(x) - if self.downsample: - x = self.downsample(x) - else: - if self.downsample: - x = self.downsample(x) - if self.learnable_sc: - x = self.conv_sc(x) - return x - - def forward(self, x): - if self.preactivation: - # h = self.activation(x) # NOT TODAY SATAN - # Andy's note: This line *must* be an out-of-place ReLU or it - # will negatively affect the shortcut connection. - h = F.relu(x) - else: - h = x - h = self.conv1(h) - h = self.conv2(self.activation(h)) - if self.downsample: - h = self.downsample(h) - - return h + self.shortcut(x) - diff --git a/codes/models/steps/__init__.py b/codes/models/archs/byol/__init__.py similarity index 100% rename from codes/models/steps/__init__.py rename to codes/models/archs/byol/__init__.py diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/archs/byol/byol_model_wrapper.py similarity index 100% rename from codes/models/byol/byol_model_wrapper.py rename to codes/models/archs/byol/byol_model_wrapper.py diff --git a/codes/models/byol/byol_structural.py b/codes/models/archs/byol/byol_structural.py similarity index 97% rename from codes/models/byol/byol_structural.py rename to codes/models/archs/byol/byol_structural.py index 17d5cfbb..80a9af75 100644 --- a/codes/models/byol/byol_structural.py +++ b/codes/models/archs/byol/byol_structural.py @@ -1,14 +1,11 @@ import copy -import random -from functools import wraps -from time import time import torch import torch.nn.functional as F from torch import nn from data.byol_attachment import reconstructed_shared_regions -from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \ +from models.archs.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \ update_moving_average from utils.util import checkpoint diff --git a/codes/models/archs/flownet2 b/codes/models/archs/flownet2 new file mode 160000 index 00000000..db2b7899 --- /dev/null +++ b/codes/models/archs/flownet2 @@ -0,0 +1 @@ +Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3 diff --git a/codes/models/archs/lambda_rrdb.py b/codes/models/archs/lambda_rrdb.py deleted file mode 100644 index 19ed468e..00000000 --- a/codes/models/archs/lambda_rrdb.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from torch import nn -from lambda_networks import LambdaLayer -from torch.nn import GroupNorm - -from models.archs.RRDBNet_arch import ResidualDenseBlock -from models.archs.arch_util import ConvGnLelu - - -class LambdaRRDB(nn.Module): - """Residual in Residual Dense Block. - - Used in RRDB-Net in ESRGAN. - - Args: - mid_channels (int): Channel number of intermediate features. - growth_channels (int): Channels for each growth. - """ - - def __init__(self, mid_channels, growth_channels=32, reduce_to=None): - super(LambdaRRDB, self).__init__() - if reduce_to is None: - reduce_to = mid_channels - self.lam1 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4) - self.gn1 = GroupNorm(num_groups=8, num_channels=mid_channels) - self.lam2 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4) - self.gn2 = GroupNorm(num_groups=8, num_channels=mid_channels) - self.lam3 = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4) - self.gn3 = GroupNorm(num_groups=8, num_channels=mid_channels) - self.conv = ConvGnLelu(reduce_to, reduce_to, kernel_size=1, bias=True, norm=False, activation=False, weight_init_factor=.1) - - def forward(self, x): - """Forward function. - - Args: - x (Tensor): Input tensor with shape (n, c, h, w). - - Returns: - Tensor: Forward results. - """ - out = self.lam1(x) - out = self.gn1(out) - out = self.lam2(out) - out = self.gn1(out) - out = self.lam3(out) - out = self.gn3(out) - return self.conv(out) * .2 + x \ No newline at end of file diff --git a/codes/models/archs/multi_res_rrdb.py b/codes/models/archs/multi_res_rrdb.py deleted file mode 100644 index 59f35e42..00000000 --- a/codes/models/archs/multi_res_rrdb.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - -from models.archs.RRDBNet_arch import RRDB -from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu, PixelUnshuffle -from utils.util import checkpoint, sequential_checkpoint - - -class MultiLevelRRDB(nn.Module): - def __init__(self, nf, gc, levels): - super().__init__() - self.levels = levels - self.level_rrdbs = nn.ModuleList([RRDB(nf, growth_channels=gc) for i in range(levels)]) - - # Trunks should be fed in in order HR->LR - def forward(self, trunk): - for i in reversed(range(self.levels)): - lvl_scale = (2**i) - lvl_res = self.level_rrdbs[i](F.interpolate(trunk, scale_factor=1/lvl_scale, mode="area"), return_residual=True) - trunk = trunk + F.interpolate(lvl_res, scale_factor=lvl_scale, mode="nearest") - return trunk - - -class MultiResRRDBNet(nn.Module): - def __init__(self, - in_channels, - out_channels, - mid_channels=64, - l1_blocks=3, - l2_blocks=4, - l3_blocks=6, - growth_channels=32, - scale=4, - ): - super().__init__() - self.scale = scale - self.in_channels = in_channels - - self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=1, padding=3) - - self.l3_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 3) for _ in range(l1_blocks)]) - self.l2_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)]) - self.l1_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 1) for _ in range(l3_blocks)]) - self.block_levels = [self.l3_blocks, self.l2_blocks, self.l1_blocks] - - self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - # upsample - self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - for m in [ - self.conv_first, self.conv_first, self.conv_body, self.conv_up1, - self.conv_up2, self.conv_hr, self.conv_last - ]: - if m is not None: - default_init_weights(m, 0.1) - - def forward(self, x): - trunk = self.conv_first(x) - for block_set in self.block_levels: - for block in block_set: - trunk = checkpoint(block, trunk) - - body_feat = self.conv_body(trunk) - feat = trunk + body_feat - - # upsample - out = self.lrelu( - self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) - if self.scale == 4: - out = self.lrelu( - self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))) - else: - out = self.lrelu(self.conv_up2(out)) - out = self.conv_last(self.lrelu(self.conv_hr(out))) - - return out - - def visual_dbg(self, step, path): - pass - - -class SteppedResRRDBNet(nn.Module): - def __init__(self, - in_channels, - out_channels, - mid_channels=64, - l1_blocks=3, - l2_blocks=3, - growth_channels=32, - scale=4, - ): - super().__init__() - self.scale = scale - self.in_channels = in_channels - - self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=2, padding=3) - self.conv_second = nn.Conv2d(mid_channels, mid_channels*2, 3, stride=2, padding=1) - - self.l1_blocks = nn.Sequential(*[RRDB(mid_channels*2, growth_channels*2) for _ in range(l1_blocks)]) - self.l1_upsample_conv = nn.Conv2d(mid_channels*2, mid_channels, 3, stride=1, padding=1) - self.l2_blocks = nn.Sequential(*[RRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)]) - - self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - # upsample - self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - for m in [ - self.conv_first, self.conv_second, self.l1_upsample_conv, self.conv_body, self.conv_up1, - self.conv_up2, self.conv_hr, self.conv_last - ]: - if m is not None: - default_init_weights(m, 0.1) - - def forward(self, x): - trunk = self.conv_first(x) - trunk = self.conv_second(trunk) - trunk = sequential_checkpoint(self.l1_blocks, len(self.l2_blocks), trunk) - trunk = F.interpolate(trunk, scale_factor=2, mode="nearest") - trunk = self.l1_upsample_conv(trunk) - trunk = sequential_checkpoint(self.l2_blocks, len(self.l2_blocks), trunk) - body_feat = self.conv_body(trunk) - feat = trunk + body_feat - - # upsample - out = self.lrelu( - self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) - if self.scale == 4: - out = self.lrelu( - self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))) - else: - out = self.lrelu(self.conv_up2(out)) - out = self.conv_last(self.lrelu(self.conv_hr(out))) - - return out - - def visual_dbg(self, step, path): - pass - - -class PixelShufflingSteppedResRRDBNet(nn.Module): - def __init__(self, - in_channels, - out_channels, - mid_channels=64, - l1_blocks=3, - l2_blocks=3, - growth_channels=32, - scale=2, - ): - super().__init__() - self.scale = scale * 2 # This RRDB operates at half-scale resolution. - self.in_channels = in_channels - - self.pix_unshuffle = PixelUnshuffle(4) - self.conv_first = nn.Conv2d(4*4*in_channels, mid_channels*2, 3, stride=1, padding=1) - - self.l1_blocks = nn.Sequential(*[RRDB(mid_channels*2, growth_channels*2) for _ in range(l1_blocks)]) - self.l1_upsample_conv = nn.Conv2d(mid_channels*2, mid_channels, 3, stride=1, padding=1) - self.l2_blocks = nn.Sequential(*[RRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)]) - - self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - # upsample - self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - for m in [ - self.conv_first, self.l1_upsample_conv, self.conv_body, self.conv_up1, - self.conv_up2, self.conv_hr, self.conv_last - ]: - if m is not None: - default_init_weights(m, 0.1) - - def forward(self, x): - trunk = self.conv_first(self.pix_unshuffle(x)) - trunk = sequential_checkpoint(self.l1_blocks, len(self.l1_blocks), trunk) - trunk = F.interpolate(trunk, scale_factor=2, mode="nearest") - trunk = self.l1_upsample_conv(trunk) - trunk = sequential_checkpoint(self.l2_blocks, len(self.l2_blocks), trunk) - body_feat = self.conv_body(trunk) - feat = trunk + body_feat - - # upsample - out = self.lrelu( - self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) - if self.scale == 4: - out = self.lrelu( - self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))) - else: - out = self.lrelu(self.conv_up2(out)) - out = self.conv_last(self.lrelu(self.conv_hr(out))) - - return out - - def visual_dbg(self, step, path): - pass diff --git a/codes/models/archs/pyramid_arch.py b/codes/models/archs/pyramid_arch.py deleted file mode 100644 index 855ffa58..00000000 --- a/codes/models/archs/pyramid_arch.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -from torch import nn - -from models.archs.arch_util import ConvGnLelu, ExpansionBlock -from models.flownet2.networks.resample2d_package.resample2d import Resample2d -from utils.util import checkpoint -import torch.nn.functional as F - - -class Pyramid(nn.Module): - def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu, - norm=True, return_outlevels=False): - super(Pyramid, self).__init__() - levels = [] - current_filters = nf - self.return_outlevels = return_outlevels - for d in range(depth): - level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)] - current_filters = int(current_filters*scale_per_level) - for pc in range(processing_convs_per_layer): - level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False)) - levels.append(nn.Sequential(*level)) - self.downsamples = nn.ModuleList(levels) - if processing_at_point > 0: - point_processor = [] - for p in range(processing_at_point): - point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False)) - self.point_processor = nn.Sequential(*point_processor) - else: - self.point_processor = None - levels = [] - for d in range(depth): - level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)] - current_filters = int(current_filters / scale_per_level) - for pc in range(processing_convs_per_layer): - level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False)) - levels.append(nn.ModuleList(level)) - self.upsamples = nn.ModuleList(levels) - - def forward(self, x): - passthroughs = [] - fea = x - for lvl in self.downsamples: - passthroughs.append(fea) - fea = lvl(fea) - out_levels = [] - fea = self.point_processor(fea) - for i, lvl in enumerate(self.upsamples): - out_levels.append(fea) - for j, sublvl in enumerate(lvl): - if j == 0: - fea = sublvl(fea, passthroughs[-1-i]) - else: - fea = sublvl(fea) - - out_levels.append(fea) - - if self.return_outlevels: - return tuple(out_levels) - else: - return fea - - -class BasicResamplingFlowNet(nn.Module): - def create_termini(self, filters): - return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True), - nn.Tanh()) - - def __init__(self, nf, resample_scale=1): - super(BasicResamplingFlowNet, self).__init__() - self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True) - self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True) - self.termini = nn.ModuleList([self.create_termini(nf*1.5**3), - self.create_termini(nf*1.5**2), - self.create_termini(nf*1.5)]) - self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True), - ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False), - ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True), - ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True), - nn.Tanh()) - self.scale = resample_scale - self.resampler = Resample2d() - - def forward(self, left, right): - fea = self.initial_conv(torch.cat([left, right], dim=1)) - levels = checkpoint(self.pyramid, fea) - flos = [] - compares = [] - for i, level in enumerate(levels): - if i == 3: - flow = checkpoint(self.terminus, level) * self.scale - else: - flow = self.termini[i](level) * self.scale - img_scale = 1/2**(3-i) - flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float())) - compares.append(F.interpolate(right, scale_factor=img_scale, mode="area")) - flos_structural_var = torch.var(flos[-1], dim=[-1,-2]) - return flos, compares, flos_structural_var diff --git a/codes/models/archs/pytorch_ssim.py b/codes/models/archs/pytorch_ssim.py deleted file mode 100644 index 5bdadb79..00000000 --- a/codes/models/archs/pytorch_ssim.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.autograd import Variable -import numpy as np -from math import exp - - -def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) - return gauss / gauss.sum() - - -def create_window(window_size, channel): - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) - return window - - -def _ssim(img1, img2, window, window_size, channel, size_average=True, raw=False): - mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) - mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) - - mu1_sq = mu1.pow(2) - mu2_sq = mu2.pow(2) - mu1_mu2 = mu1 * mu2 - - sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq - sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq - sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 - - C1 = 0.01 ** 2 - C2 = 0.03 ** 2 - - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) - - if size_average: - return ssim_map.mean() - elif raw: - return ssim_map - else: - return ssim_map.mean(1).mean(1).mean(1) - - -class SSIM(torch.nn.Module): - def __init__(self, window_size=11, size_average=True, raw=False): - super(SSIM, self).__init__() - self.window_size = window_size - self.size_average = size_average - self.raw = raw - self.channel = 1 - self.window = create_window(window_size, self.channel) - - def forward(self, img1, img2): - (_, channel, _, _) = img1.size() - - if channel == self.channel and self.window.data.type() == img1.data.type(): - window = self.window - else: - window = create_window(self.window_size, channel) - - if img1.is_cuda: - window = window.cuda(img1.get_device()) - window = window.type_as(img1) - - self.window = window - self.channel = channel - - return _ssim(img1, img2, window, self.window_size, channel, self.size_average, self.raw) - - -def ssim(img1, img2, window_size=11, size_average=True): - (_, channel, _, _) = img1.size() - window = create_window(window_size, channel) - - if img1.is_cuda: - window = window.cuda(img1.get_device()) - window = window.type_as(img1) - - return _ssim(img1, img2, window, window_size, channel, size_average) \ No newline at end of file diff --git a/codes/models/archs/rcan.py b/codes/models/archs/rcan.py deleted file mode 100644 index 684727db..00000000 --- a/codes/models/archs/rcan.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch.nn as nn -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from utils.util import checkpoint - -from torch.autograd import Variable - -def default_conv(in_channels, out_channels, kernel_size, bias=True): - return nn.Conv2d( - in_channels, out_channels, kernel_size, - padding=(kernel_size//2), bias=bias) - -class MeanShift(nn.Conv2d): - def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): - super(MeanShift, self).__init__(3, 3, kernel_size=1) - std = torch.Tensor(rgb_std) - self.weight.data = torch.eye(3).view(3, 3, 1, 1) - self.weight.data.div_(std.view(3, 1, 1, 1)) - self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) - self.bias.data.div_(std) - self.requires_grad = False - -class BasicBlock(nn.Sequential): - def __init__( - self, in_channels, out_channels, kernel_size, stride=1, bias=False, - bn=True, act=nn.ReLU(True)): - - m = [nn.Conv2d( - in_channels, out_channels, kernel_size, - padding=(kernel_size//2), stride=stride, bias=bias) - ] - if bn: m.append(nn.BatchNorm2d(out_channels)) - if act is not None: m.append(act) - super(BasicBlock, self).__init__(*m) - -class ResBlock(nn.Module): - def __init__( - self, conv, n_feat, kernel_size, - bias=True, bn=False, act=nn.ReLU(True), res_scale=1): - - super(ResBlock, self).__init__() - m = [] - for i in range(2): - m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) - if bn: m.append(nn.BatchNorm2d(n_feat)) - if i == 0: m.append(act) - - self.body = nn.Sequential(*m) - self.res_scale = res_scale - - def forward(self, x): - res = self.body(x).mul(self.res_scale) - res += x - - return res - -class Upsampler(nn.Sequential): - def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): - - m = [] - if (scale & (scale - 1)) == 0: # Is scale = 2^n? - for _ in range(int(math.log(scale, 2))): - m.append(conv(n_feat, 4 * n_feat, 3, bias)) - m.append(nn.PixelShuffle(2)) - if bn: m.append(nn.BatchNorm2d(n_feat)) - if act: m.append(act()) - elif scale == 3: - m.append(conv(n_feat, 9 * n_feat, 3, bias)) - m.append(nn.PixelShuffle(3)) - if bn: m.append(nn.BatchNorm2d(n_feat)) - if act: m.append(act()) - else: - raise NotImplementedError - - super(Upsampler, self).__init__(*m) - -def make_model(args, parent=False): - return RCAN(args) - - -## Channel Attention (CA) Layer -class CALayer(nn.Module): - def __init__(self, channel, reduction=16): - super(CALayer, self).__init__() - # global average pooling: feature --> point - self.avg_pool = nn.AdaptiveAvgPool2d(1) - # feature channel downscale and upscale --> channel weight - self.conv_du = nn.Sequential( - nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), - nn.Sigmoid() - ) - - def forward(self, x): - y = self.avg_pool(x) - y = self.conv_du(y) - return x * y - - -## Residual Channel Attention Block (RCAB) -class RCAB(nn.Module): - def __init__( - self, conv, n_feat, kernel_size, reduction, - bias=True, bn=False, act=nn.ReLU(True), res_scale=1): - - super(RCAB, self).__init__() - modules_body = [] - for i in range(2): - modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) - if bn: modules_body.append(nn.BatchNorm2d(n_feat)) - if i == 0: modules_body.append(act) - modules_body.append(CALayer(n_feat, reduction)) - self.body = nn.Sequential(*modules_body) - self.res_scale = res_scale - - def forward(self, x): - res = self.body(x) - # res = self.body(x).mul(self.res_scale) - res += x - return res - - -## Residual Group (RG) -class ResidualGroup(nn.Module): - def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): - super(ResidualGroup, self).__init__() - modules_body = [] - modules_body = [ - RCAB( - conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ - for _ in range(n_resblocks)] - modules_body.append(conv(n_feat, n_feat, kernel_size)) - self.body = nn.Sequential(*modules_body) - - def forward(self, x): - res = self.body(x) - res += x - return res - - -## Residual Channel Attention Network (RCAN) -class RCAN(nn.Module): - def __init__(self, args, conv=default_conv): - super(RCAN, self).__init__() - - n_resgroups = args.n_resgroups - n_resblocks = args.n_resblocks - n_feats = args.n_feats - kernel_size = 3 - reduction = args.reduction - scale = args.scale - act = nn.ReLU(True) - - # RGB mean for DIV2K - rgb_mean = (0.4488, 0.4371, 0.4040) - rgb_std = (1.0, 1.0, 1.0) - self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std) - - # define head module - modules_head = [conv(args.n_colors, n_feats, kernel_size)] - - # define body module - modules_body = [ - ResidualGroup( - conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ - for _ in range(n_resgroups)] - - modules_body.append(conv(n_feats, n_feats, kernel_size)) - - # define tail module - modules_tail = [ - Upsampler(conv, scale, n_feats, act=False), - conv(n_feats, args.n_colors, kernel_size)] - - self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) - - self.head = nn.Sequential(*modules_head) - self.body = nn.Sequential(*modules_body) - self.tail = nn.Sequential(*modules_tail) - - def forward(self, x): - x = self.sub_mean(x) - x = self.head(x) - - res = self.body(x) - res += x - - x = self.tail(res) - x = self.add_mean(x) - - return x - - def load_state_dict(self, state_dict, strict=False): - own_state = self.state_dict() - for name, param in state_dict.items(): - if name in own_state: - if isinstance(param, nn.Parameter): - param = param.data - try: - own_state[name].copy_(param) - except Exception: - if name.find('tail') >= 0: - print('Replace pre-trained upsampler to new one...') - else: - raise RuntimeError('While copying the parameter named {}, ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}.' - .format(name, own_state[name].size(), param.size())) - elif strict: - if name.find('tail') == -1: - raise KeyError('unexpected key "{}" in state_dict' - .format(name)) - - if strict: - missing = set(own_state.keys()) - set(state_dict.keys()) - if len(missing) > 0: - raise KeyError('missing keys in state_dict: "{}"'.format(missing)) \ No newline at end of file diff --git a/codes/models/archs/tecogan/__init__.py b/codes/models/archs/tecogan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/archs/teco_resgen.py b/codes/models/archs/tecogan/teco_resgen.py similarity index 100% rename from codes/models/archs/teco_resgen.py rename to codes/models/archs/tecogan/teco_resgen.py diff --git a/codes/models/archs/transformers/igpt/gpt2.py b/codes/models/archs/transformers/igpt/gpt2.py index 39388b11..e5c4fe92 100644 --- a/codes/models/archs/transformers/igpt/gpt2.py +++ b/codes/models/archs/transformers/igpt/gpt2.py @@ -2,8 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -import torchvision -from models.steps.injectors import Injector +from models.injectors import Injector from utils.util import checkpoint diff --git a/codes/models/custom_training_components/__init__.py b/codes/models/custom_training_components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/custom_training_components/progressive_zoom.py similarity index 97% rename from codes/models/steps/progressive_zoom.py rename to codes/models/custom_training_components/progressive_zoom.py index 2b27cf7b..d3ad0afa 100644 --- a/codes/models/steps/progressive_zoom.py +++ b/codes/models/custom_training_components/progressive_zoom.py @@ -6,9 +6,8 @@ import torchvision from torch.cuda.amp import autocast from data.multiscale_dataset import build_multiscale_patch_index_map -from models.steps.injectors import Injector -from models.steps.losses import extract_params_from_state -from models.steps.tecogan_losses import extract_inputs_index +from models.injectors import Injector +from models.losses import extract_params_from_state import os.path as osp diff --git a/codes/models/steps/stereoscopic.py b/codes/models/custom_training_components/stereoscopic.py similarity index 90% rename from codes/models/steps/stereoscopic.py rename to codes/models/custom_training_components/stereoscopic.py index ff36be6d..2ca416aa 100644 --- a/codes/models/steps/stereoscopic.py +++ b/codes/models/custom_training_components/stereoscopic.py @@ -1,8 +1,8 @@ import torch from torch.cuda.amp import autocast -from models.flownet2.networks.resample2d_package.resample2d import Resample2d -from models.flownet2.utils.flow_utils import flow2img -from models.steps.injectors import Injector +from models.archs.flownet2.networks import Resample2d +from models.archs.flownet2 import flow2img +from models.injectors import Injector def create_stereoscopic_injector(opt, env): diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/custom_training_components/tecogan_losses.py similarity index 98% rename from codes/models/steps/tecogan_losses.py rename to codes/models/custom_training_components/tecogan_losses.py index 7f921204..6f91ec32 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/custom_training_components/tecogan_losses.py @@ -1,15 +1,15 @@ from torch.cuda.amp import autocast from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty -from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name -from models.flownet2.networks.resample2d_package.resample2d import Resample2d -from models.steps.injectors import Injector +from models.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name +from models.archs.flownet2.networks import Resample2d +from models.injectors import Injector import torch import torch.nn.functional as F import os import os.path as osp import torchvision -import torch.distributed as dist + def create_teco_loss(opt, env): type = opt['type'] diff --git a/codes/models/steps/injectors.py b/codes/models/injectors.py similarity index 83% rename from codes/models/steps/injectors.py rename to codes/models/injectors.py index a447fa49..3b10c4d6 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/injectors.py @@ -3,22 +3,20 @@ import random import torch.nn from torch.cuda.amp import autocast -from models.archs.SPSR_arch import ImageGradientNoPadding -from models.archs.pytorch_ssim import SSIM from utils.weight_scheduler import get_scheduler_for_opt -from models.steps.losses import extract_params_from_state +from models.losses import extract_params_from_state # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. def create_injector(opt_inject, env): type = opt_inject['type'] if 'teco_' in type: - from models.steps.tecogan_losses import create_teco_injector + from models.custom_training_components import create_teco_injector return create_teco_injector(opt_inject, env) elif 'progressive_' in type: - from models.steps.progressive_zoom import create_progressive_zoom_injector + from models.custom_training_components import create_progressive_zoom_injector return create_progressive_zoom_injector(opt_inject, env) elif 'stereoscopic_' in type: - from models.steps.stereoscopic import create_stereoscopic_injector + from models.custom_training_components import create_stereoscopic_injector return create_stereoscopic_injector(opt_inject, env) elif 'igpt' in type: from models.archs.transformers.igpt import gpt2 @@ -29,8 +27,6 @@ def create_injector(opt_inject, env): return DiscriminatorInjector(opt_inject, env) elif type == 'scheduled_scalar': return ScheduledScalarInjector(opt_inject, env) - elif type == 'img_grad': - return ImageGradientInjector(opt_inject, env) elif type == 'add_noise': return AddNoiseInjector(opt_inject, env) elif type == 'greyscale': @@ -47,14 +43,10 @@ def create_injector(opt_inject, env): return ForEachInjector(opt_inject, env) elif type == 'constant': return ConstantInjector(opt_inject, env) - elif type == 'fft': - return ImageFftInjector(opt_inject, env) elif type == 'extract_indices': return IndicesExtractor(opt_inject, env) elif type == 'random_shift': return RandomShiftInjector(opt_inject, env) - elif type == 'psnr': - return PsnrInjector(opt_inject, env) elif type == 'batch_rotate': return BatchRotateInjector(opt_inject, env) elif type == 'sr_diffs': @@ -134,16 +126,6 @@ class DiscriminatorInjector(Injector): return new_state -# Creates an image gradient from [in] and injects it into [out] -class ImageGradientInjector(Injector): - def __init__(self, opt, env): - super(ImageGradientInjector, self).__init__(opt, env) - self.img_grad_fn = ImageGradientNoPadding().to(env['device']) - - def forward(self, state): - return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])} - - # Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence # of something over time. class ScheduledScalarInjector(Injector): @@ -320,37 +302,6 @@ class ConstantInjector(Injector): return { self.opt['out']: out } -class ImageFftInjector(Injector): - def __init__(self, opt, env): - super(ImageFftInjector, self).__init__(opt, env) - self.is_forward = opt['forward'] # Whether to compute a forward FFT or backward. - self.eps = 1e-100 - - def forward(self, state): - if self.forward: - fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True) - b, f, h, w, c = fftim.shape - fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w) - # Normalize across spatial dimension - mean = torch.mean(fftim, dim=(0,1)) - fftim = fftim - mean - std = torch.std(fftim, dim=(0,1)) - fftim = (fftim + self.eps) / std - return {self.output: fftim, - '%s_std' % (self.output,): std, - '%s_mean' % (self.output,): mean} - else: - b, f, h, w = state[self.input].shape - # First, de-normalize the FFT. - mean = state['%s_mean' % (self.input,)] - std = state['%s_std' % (self.input,)] - fftim = state[self.input] * std + mean - self.eps - # Second, recover the FFT dimensions from the given filters. - fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2) - im = torch.irfft(fftim, signal_ndim=2, normalized=True) - return {self.output: im} - - class IndicesExtractor(Injector): def __init__(self, opt, env): super(IndicesExtractor, self).__init__(opt, env) @@ -374,21 +325,6 @@ class RandomShiftInjector(Injector): return {self.output: img} -class PsnrInjector(Injector): - def __init__(self, opt, env): - super(PsnrInjector, self).__init__(opt, env) - self.ssim = SSIM(size_average=False, raw=True) - self.scale = opt['output_scale_divisor'] - self.exp = opt['exponent'] if 'exponent' in opt.keys() else 1 - - def forward(self, state): - img1, img2 = state[self.input[0]], state[self.input[1]] - ssim = self.ssim(img1, img2) - areal_se = torch.nn.functional.interpolate(ssim, scale_factor=1/self.scale, - mode="area") - return {self.output: areal_se} - - class BatchRotateInjector(Injector): def __init__(self, opt, env): super(BatchRotateInjector, self).__init__(opt, env) @@ -436,7 +372,7 @@ class MultiFrameCombiner(Injector): self.in_hq_key = opt['in_hq'] self.out_lq_key = opt['out'] self.out_hq_key = opt['out_hq'] - from models.flownet2.networks.resample2d_package.resample2d import Resample2d + from models.archs.flownet2.networks import Resample2d self.resampler = Resample2d() def combine(self, state): diff --git a/codes/models/steps/losses.py b/codes/models/losses.py similarity index 99% rename from codes/models/steps/losses.py rename to codes/models/losses.py index 77d748d8..9c604417 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/losses.py @@ -6,13 +6,12 @@ from models.loss import GANLoss import random import functools import torch.nn.functional as F -import numpy as np def create_loss(opt_loss, env): type = opt_loss['type'] if 'teco_' in type: - from models.steps.tecogan_losses import create_teco_loss + from models.custom_training_components.tecogan_losses import create_teco_loss return create_teco_loss(opt_loss, env) elif 'stylegan2_' in type: from models.archs.stylegan import create_stylegan2_loss diff --git a/codes/models/networks.py b/codes/models/networks.py index 056a1f19..abbde8bc 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -10,16 +10,12 @@ import models.archs.stylegan.stylegan2_lucidrains as stylegan2 import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch import models.archs.RRDBNet_arch as RRDBNet_arch -import models.archs.SPSR_arch as spsr -import models.archs.SRResNet_arch as SRResNet_arch import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch import models.archs.discriminator_vgg_arch as SRGAN_arch import models.archs.feature_arch as feature_arch -import models.archs.rcan as rcan from models.archs import srg2_classic -from models.archs.biggan.biggan_discriminator import BigGanDiscriminator from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator -from models.archs.teco_resgen import TecoGen +from models.archs.tecogan.teco_resgen import TecoGen from utils.util import opt_get logger = logging.getLogger('base') @@ -30,11 +26,7 @@ def define_G(opt, opt_net, scale=None): scale = opt['scale'] which_model = opt_net['which_model_G'] - # image restoration - if which_model == 'MSRResNet': - netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) - elif 'RRDBNet' in which_model: + if 'RRDBNet' in which_model: if which_model == 'RRDBNetBypass': block = RRDBNet_arch.RRDBWithBypass elif which_model == 'RRDBNetLambda': @@ -50,24 +42,6 @@ def define_G(opt, opt_net, scale=None): mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc, initial_stride=initial_stride) - elif which_model == "multires_rrdb": - from models.archs.multi_res_rrdb import MultiResRRDBNet - netG = MultiResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'], - l2_blocks=opt_net['l2'], l3_blocks=opt_net['l3'], - growth_channels=opt_net['gc'], scale=opt_net['scale']) - elif which_model == "twostep_rrdb": - from models.archs.multi_res_rrdb import PixelShufflingSteppedResRRDBNet - netG = PixelShufflingSteppedResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'], - l2_blocks=opt_net['l2'], - growth_channels=opt_net['gc'], scale=opt_net['scale']) - elif which_model == 'rcan': - #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats - opt_net['rgb_range'] = 255 - opt_net['n_colors'] = 3 - args_obj = munchify(opt_net) - netG = rcan.RCAN(args_obj) elif which_model == "ConfigurableSwitchedResidualGenerator2": netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'], @@ -87,22 +61,8 @@ def define_G(opt, opt_net, scale=None): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) - elif which_model == 'spsr': - netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], - nb=opt_net['nb'], upscale=opt_net['scale']) - elif which_model == 'spsr_net_improved': - netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], - nb=opt_net['nb'], upscale=opt_net['scale']) - elif which_model == "spsr_switched": - netG = spsr.SwitchedSpsr(in_nc=3, nf=opt_net['nf'], upscale=opt_net['scale'], init_temperature=opt_net['temperature']) - elif which_model == "spsr7": - recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False - xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], - multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, - init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent) elif which_model == "flownet2": - from models.flownet2.models import FlowNet2 + from models.archs.flownet2 import FlowNet2 ld = 'load_path' in opt_net.keys() args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld}) netG = FlowNet2(args) @@ -148,12 +108,12 @@ def define_G(opt, opt_net, scale=None): from models.archs.transformers.igpt.gpt2 import iGPT2 netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file']) elif which_model == 'byol': - from models.byol.byol_model_wrapper import BYOL + from models.archs.byol.byol_model_wrapper import BYOL subnet = define_G(opt, opt_net['subnet']) netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False)) elif which_model == 'structural_byol': - from models.byol.byol_structural import StructuralBYOL + from models.archs.byol.byol_structural import StructuralBYOL subnet = define_G(opt, opt_net['subnet']) netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]), @@ -206,8 +166,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False): #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True) #netD.load_state_dict(state_dict, strict=False) netD.fc = torch.nn.Linear(512 * 4, 1) - elif which_model == 'biggan_resnet': - netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2)) elif which_model == 'discriminator_pix': netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet": diff --git a/codes/models/steps/steps.py b/codes/models/steps.py similarity index 98% rename from codes/models/steps/steps.py rename to codes/models/steps.py index 8659c6f3..3eeb902c 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps.py @@ -1,12 +1,12 @@ -from torch.cuda.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler from utils.loss_accumulator import LossAccumulator from torch.nn import Module import logging -from models.steps.losses import create_loss +from models.losses import create_loss import torch from collections import OrderedDict -from .injectors import create_injector +from models.injectors import create_injector from utils.util import recursively_detach logger = logging.getLogger('base') diff --git a/codes/scripts/use_generator_as_filter.py b/codes/scripts/use_generator_as_filter.py index 26718702..e65c82cd 100644 --- a/codes/scripts/use_generator_as_filter.py +++ b/codes/scripts/use_generator_as_filter.py @@ -8,14 +8,11 @@ import os import utils from models.ExtensibleTrainer import ExtensibleTrainer from models.networks import define_F -from models.steps.losses import FeatureLoss from utils import options as option import utils.util as util from data import create_dataset, create_dataloader from tqdm import tqdm import torch -import torchvision - if __name__ == "__main__": bin_path = "f:\\binned" diff --git a/codes/utils/onnx_inference.py b/codes/utils/onnx_inference.py deleted file mode 100644 index 56814b12..00000000 --- a/codes/utils/onnx_inference.py +++ /dev/null @@ -1,22 +0,0 @@ -import onnx -import numpy as np -import time - -init_temperature = 10 -final_temperature_step = 50 -heightened_final_step = 90 -heightened_temp_min = .1 - -for step in range(100): - temp = max(1, 1 + init_temperature * (final_temperature_step - step) / final_temperature_step) - if temp == 1 and step > final_temperature_step and heightened_final_step and heightened_final_step != 1: - # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. - # without this, the attention specificity "spikes" incredibly fast in the last few iterations. - h_steps_total = heightened_final_step - final_temperature_step - h_steps_current = min(step - final_temperature_step, h_steps_total) - # The "gap" will represent the steps that need to be traveled as a linear function. - h_gap = 1 / heightened_temp_min - temp = h_gap * h_steps_current / h_steps_total - # Invert temperature to represent reality on this side of the curve - temp = 1 / temp - print("%i: %f" % (step, temp)) \ No newline at end of file