forked from mrq/DL-Art-School
Large cleanup
Removed a lot of old code that I won't be touching again. Refactored some code elements into more logical places.
parent 2f0a52b7db
commit b905b108da
.gitmodules (vendored): 3 changed lines
@@ -8,3 +8,6 @@
	path = codes/models/flownet2
	url = https://github.com/neonbjb/flownet2-pytorch.git
	branch = master
[submodule "codes/models/archs/flownet2"]
	path = codes/models/archs/flownet2
	url = https://github.com/neonbjb/flownet2-pytorch.git
@@ -8,8 +8,8 @@ import torch.nn as nn
import models.lr_scheduler as lr_scheduler
import models.networks as networks
from models.base_model import BaseModel
from models.steps.injectors import create_injector
from models.steps.steps import ConfigurableStep
from models.injectors import create_injector
from models.steps import ConfigurableStep
from models.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils
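For downstream scripts that imported these helpers, the change in this hunk is only the module path; a minimal sketch of the relocation, assuming the factory and class bodies themselves are untouched by this commit:

    # before this commit
    from models.steps.injectors import create_injector
    from models.steps.steps import ConfigurableStep
    # after this commit
    from models.injectors import create_injector
    from models.steps import ConfigurableStep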
@ -1,668 +0,0 @@
|
|||
import functools
|
||||
import os
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from utils.util import checkpoint
|
||||
|
||||
from models.archs import SPSR_util as B
|
||||
from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \
|
||||
QueryKeyMultiplexer, QueryKeyPyramidMultiplexer, ConvBasisMultiplexer
|
||||
from models.archs.arch_util import ConvGnLelu, UpconvBlock, MultiConvBlock, ReferenceJoinBlock
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from switched_conv.switched_conv_util import save_attention_to_image_rgb
|
||||
from .RRDBNet_arch import RRDB
|
||||
|
||||
|
||||
class ImageGradient(nn.Module):
|
||||
def __init__(self):
|
||||
super(ImageGradient, self).__init__()
|
||||
kernel_v = [[0, -1, 0],
|
||||
[0, 0, 0],
|
||||
[0, 1, 0]]
|
||||
kernel_h = [[0, 0, 0],
|
||||
[-1, 0, 1],
|
||||
[0, 0, 0]]
|
||||
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
|
||||
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
|
||||
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
|
||||
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()
|
||||
|
||||
def forward(self, x):
|
||||
x0 = x[:, 0]
|
||||
x1 = x[:, 1]
|
||||
x2 = x[:, 2]
|
||||
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2)
|
||||
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2)
|
||||
|
||||
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2)
|
||||
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2)
|
||||
|
||||
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2)
|
||||
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2)
|
||||
|
||||
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
|
||||
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
|
||||
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
|
||||
|
||||
x = torch.cat([x0, x1, x2], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class ImageGradientNoPadding(nn.Module):
|
||||
def __init__(self):
|
||||
super(ImageGradientNoPadding, self).__init__()
|
||||
kernel_v = [[0, -1, 0],
|
||||
[0, 0, 0],
|
||||
[0, 1, 0]]
|
||||
kernel_h = [[0, 0, 0],
|
||||
[-1, 0, 1],
|
||||
[0, 0, 0]]
|
||||
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
|
||||
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
|
||||
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False)
|
||||
|
||||
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = x.float()
|
||||
x_list = []
|
||||
for i in range(x.shape[1]):
|
||||
x_i = x[:, i]
|
||||
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
|
||||
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
|
||||
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
|
||||
x_list.append(x_i)
|
||||
|
||||
x = torch.cat(x_list, dim = 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
####################
|
||||
# Generator
|
||||
####################
|
||||
|
||||
class SPSRNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
|
||||
super(SPSRNet, self).__init__()
|
||||
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
self.scale=upscale
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
|
||||
self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=7, norm=False, activation=False)
|
||||
self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
rb_blocks = [RRDB(nf) for _ in range(nb)]
|
||||
|
||||
LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
upsample_block = UpconvBlock
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, activation=True)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)]
|
||||
|
||||
self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||
self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
|
||||
*upsampler, self.HR_conv0_new)
|
||||
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
|
||||
self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=3, norm=False, activation=False)
|
||||
self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_1 = RRDB(nf * 2)
|
||||
|
||||
self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_2 = RRDB(nf * 2)
|
||||
|
||||
self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_3 = RRDB(nf * 2)
|
||||
|
||||
self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_4 = RRDB(nf * 2)
|
||||
|
||||
self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
if upscale == 3:
|
||||
b_upsampler = UpconvBlock(nf, nf, activation=True)
|
||||
else:
|
||||
b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)]
|
||||
|
||||
b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||
b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
|
||||
|
||||
self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False)
|
||||
|
||||
self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.f_block = RRDB(nf * 2)
|
||||
|
||||
self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||
self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
|
||||
def bl1(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i](x)
|
||||
return x
|
||||
|
||||
def bl2(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i+5](x)
|
||||
return x
|
||||
|
||||
def bl3(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i+10](x)
|
||||
return x
|
||||
|
||||
def bl4(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i+15](x)
|
||||
return x
|
||||
|
||||
def bl5(self, x):
|
||||
block_list = self.model[1].sub
|
||||
x = block_list[20:](x)
|
||||
return x
|
||||
|
||||
def bl6(self, x_ori, x):
|
||||
x = x_ori + x
|
||||
x = self.model[2:](x)
|
||||
x = self.HR_conv1_new(x)
|
||||
return x
|
||||
|
||||
def branch_bl1(self, x_grad, ref_grad):
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_b_ref = self.b_ref_conv(ref_grad)
|
||||
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
|
||||
return x_b_fea
|
||||
|
||||
def branch_bl2(self, x_b_fea, x_fea1):
|
||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||
x_cat_1 = self.b_block_1(x_cat_1)
|
||||
x_cat_1 = self.b_concat_1(x_cat_1)
|
||||
return x_cat_1
|
||||
|
||||
def branch_bl3(self, x_cat_1, x_fea2):
|
||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||
x_cat_2 = self.b_block_2(x_cat_2)
|
||||
x_cat_2 = self.b_concat_2(x_cat_2)
|
||||
return x_cat_2
|
||||
|
||||
def branch_bl4(self, x_cat_2, x_fea3):
|
||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||
x_cat_3 = self.b_block_3(x_cat_3)
|
||||
x_cat_3 = self.b_concat_3(x_cat_3)
|
||||
return x_cat_3
|
||||
|
||||
def branch_bl5(self, x_cat_3, x_fea4, x_b_fea):
|
||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||
x_cat_4 = self.b_block_4(x_cat_4)
|
||||
x_cat_4 = self.b_concat_4(x_cat_4)
|
||||
x_cat_4 = self.b_LR_conv(x_cat_4)
|
||||
x_cat_4 = x_cat_4 + x_b_fea
|
||||
x_branch = self.b_module(x_cat_4)
|
||||
return x_branch
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
b,f,h,w = x.shape
|
||||
if ref is None:
|
||||
ref = torch.zeros((b,f,h*self.scale,w*self.scale), device=x.device, dtype=x.dtype)
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
ref_grad = self.get_g_nopadding(ref)
|
||||
x = self.model[0](x)
|
||||
x_ref = self.ref_conv(ref)
|
||||
x = self.join_conv(torch.cat([x, x_ref], dim=1))
|
||||
|
||||
x, block_list = self.model[1](x)
|
||||
x_ori = x
|
||||
x = checkpoint(self.bl1, x)
|
||||
x_fea1 = x
|
||||
x = checkpoint(self.bl2, x)
|
||||
x_fea2 = x
|
||||
x = checkpoint(self.bl3, x)
|
||||
x_fea3 = x
|
||||
x = checkpoint(self.bl4, x)
|
||||
x_fea4 = x
|
||||
x = checkpoint(self.bl5, x)
|
||||
x = checkpoint(self.bl6, x_ori, x)
|
||||
|
||||
x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad)
|
||||
x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1)
|
||||
x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2)
|
||||
x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3)
|
||||
x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea)
|
||||
|
||||
x_out_branch = checkpoint(self.conv_w, x_branch)
|
||||
########
|
||||
x_branch_d = x_branch
|
||||
x_f_cat = torch.cat([x_branch_d, x], dim=1)
|
||||
x_f_cat = checkpoint(self.f_block, x_f_cat)
|
||||
x_out = self.f_concat(x_f_cat)
|
||||
x_out = checkpoint(self.f_HR_conv0, x_out)
|
||||
x_out = checkpoint(self.f_HR_conv1, x_out)
|
||||
|
||||
#########
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
||||
|
||||
class SPSRNetSimplified(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, upscale=4):
|
||||
super(SPSRNetSimplified, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
# Feature branch
|
||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.model_shortcut_blk = nn.Sequential(*[RRDB(nf, gc=32) for _ in range(nb)])
|
||||
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
|
||||
self.feature_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
|
||||
# Grad branch
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.b_concat_decimate_1 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self.b_proc_block_1 = RRDB(nf, gc=32)
|
||||
self.b_concat_decimate_2 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self.b_proc_block_2 = RRDB(nf, gc=32)
|
||||
self.b_concat_decimate_3 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self.b_proc_block_3 = RRDB(nf, gc=32)
|
||||
self.b_concat_decimate_4 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self.b_proc_block_4 = RRDB(nf, gc=32)
|
||||
|
||||
# Upsampling
|
||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
b_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
|
||||
grad_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
grad_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.branch_upsample = B.sequential(*b_upsampler, grad_hr_conv1, grad_hr_conv2)
|
||||
# Conv used to output grad branch shortcut.
|
||||
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
|
||||
# Conjoin branch.
|
||||
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
|
||||
self._branch_pretrain_concat = ConvGnLelu(nf * 2, nf, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self._branch_pretrain_block = RRDB(nf * 2, gc=32)
|
||||
self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
x = self.model_fea_conv(x)
|
||||
|
||||
x_ori = x
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i](x)
|
||||
x_fea1 = x
|
||||
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i + 5](x)
|
||||
x_fea2 = x
|
||||
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i + 10](x)
|
||||
x_fea3 = x
|
||||
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i + 15](x)
|
||||
x_fea4 = x
|
||||
|
||||
x = self.model_shortcut_blk[20:](x)
|
||||
x = self.feature_lr_conv(x)
|
||||
|
||||
# short cut
|
||||
x = x_ori + x
|
||||
x = self.model_upsampler(x)
|
||||
x = self.feature_hr_conv1(x)
|
||||
x = self.feature_hr_conv2(x)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||
|
||||
x_cat_1 = self.b_concat_decimate_1(x_cat_1)
|
||||
x_cat_1 = self.b_proc_block_1(x_cat_1)
|
||||
|
||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||
|
||||
x_cat_2 = self.b_concat_decimate_2(x_cat_2)
|
||||
x_cat_2 = self.b_proc_block_2(x_cat_2)
|
||||
|
||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||
|
||||
x_cat_3 = self.b_concat_decimate_3(x_cat_3)
|
||||
x_cat_3 = self.b_proc_block_3(x_cat_3)
|
||||
|
||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||
|
||||
x_cat_4 = self.b_concat_decimate_4(x_cat_4)
|
||||
x_cat_4 = self.b_proc_block_4(x_cat_4)
|
||||
|
||||
x_cat_4 = self.grad_lr_conv(x_cat_4)
|
||||
|
||||
# short cut
|
||||
x_cat_4 = x_cat_4 + x_b_fea
|
||||
x_branch = self.branch_upsample(x_cat_4)
|
||||
x_out_branch = self.grad_branch_output_conv(x_branch)
|
||||
|
||||
########
|
||||
x_branch_d = x_branch
|
||||
x__branch_pretrain_cat = torch.cat([x_branch_d, x], dim=1)
|
||||
x__branch_pretrain_cat = self._branch_pretrain_block(x__branch_pretrain_cat)
|
||||
x_out = self._branch_pretrain_concat(x__branch_pretrain_cat)
|
||||
x_out = self._branch_pretrain_HR_conv0(x_out)
|
||||
x_out = self._branch_pretrain_HR_conv1(x_out)
|
||||
|
||||
#########
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
||||
|
||||
# Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding.
|
||||
class Spsr7(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10):
|
||||
super(Spsr7, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
# processing the input embedding
|
||||
self.reference_embedding = ReferenceImageBranch(nf)
|
||||
|
||||
self.recurrent = recurrent
|
||||
if recurrent:
|
||||
self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False,
|
||||
bias=True)
|
||||
self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01)
|
||||
|
||||
# switch options
|
||||
self.nf = nf
|
||||
transformation_filters = nf
|
||||
self.transformation_counts = xforms
|
||||
multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters, embedding_channels=512, reductions=multiplexer_reductions)
|
||||
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
|
||||
transformation_filters, kernel_size=3, depth=3,
|
||||
weight_init_factor=.1)
|
||||
|
||||
# Feature branch
|
||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
|
||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
|
||||
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False)
|
||||
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
|
||||
|
||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True)
|
||||
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
|
||||
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
|
||||
|
||||
# Join branch (grad+fea)
|
||||
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
|
||||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3)
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
|
||||
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
|
||||
self.attentions = None
|
||||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = 10000
|
||||
self.lr = None
|
||||
|
||||
def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None):
|
||||
# The attention_maps debugger outputs <x>. Save that here.
|
||||
self.lr = x.detach().cpu()
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
ref_code = self.reference_embedding(ref, ref_center)
|
||||
ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
|
||||
|
||||
x = self.model_fea_conv(x)
|
||||
if self.recurrent:
|
||||
rec = self.model_recurrent_conv(recurrent)
|
||||
br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1))
|
||||
x = x + br
|
||||
|
||||
x1 = x
|
||||
x1, a1 = self.sw1(x1, identity=x, att_in=(x1, ref_embedding), do_checkpointing=True)
|
||||
|
||||
x2 = x1
|
||||
x2, a2 = self.sw2(x2, identity=x1, att_in=(x2, ref_embedding), do_checkpointing=True)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
x_grad, grad_fea_std = checkpoint(self.grad_ref_join, x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity, att_in=(x_grad, ref_embedding), do_checkpointing=True)
|
||||
x_grad = checkpoint(self.grad_lr_conv, x_grad)
|
||||
x_grad = checkpoint(self.grad_lr_conv2, x_grad)
|
||||
x_grad_out = checkpoint(self.upsample_grad, x_grad)
|
||||
x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out)
|
||||
|
||||
x_out = x2
|
||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, identity=x2, att_in=(x_out, ref_embedding), do_checkpointing=True)
|
||||
x_out = checkpoint(self.final_lr_conv, x_out)
|
||||
x_out = checkpoint(self.upsample, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv1, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv2, x_out)
|
||||
|
||||
self.attentions = [a1, a2, a3, a4]
|
||||
self.grad_fea_std = grad_fea_std.detach().cpu()
|
||||
self.fea_grad_std = fea_grad_std.detach().cpu()
|
||||
return x_grad_out, x_out
|
||||
|
||||
def set_temperature(self, temp):
|
||||
[sw.set_temperature(temp) for sw in self.switches]
|
||||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
if self.attentions:
|
||||
temp = max(1, 1 + self.init_temperature *
|
||||
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||
self.set_temperature(temp)
|
||||
if step % 500 == 0:
|
||||
output_path = os.path.join(experiments_path, "attention_maps")
|
||||
prefix = "amap_%i_a%i_%%i.png"
|
||||
[save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
|
||||
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))
|
||||
|
||||
def get_debug_values(self, step, net_name):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
|
||||
val = {"switch_temperature": temp,
|
||||
"grad_branch_feat_intg_std_dev": self.grad_fea_std,
|
||||
"conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
|
||||
for i in range(len(means)):
|
||||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
def __init__(self, nf, num_transforms, multiplexer_reductions, init_temperature=10, has_ref=True):
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.nf = nf
|
||||
self.transformation_counts = num_transforms
|
||||
multiplx_fn = functools.partial(QueryKeyMultiplexer, nf, embedding_channels=512, reductions=multiplexer_reductions)
|
||||
transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5),
|
||||
nf, kernel_size=3, depth=4,
|
||||
weight_init_factor=.1)
|
||||
if has_ref:
|
||||
self.ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
|
||||
else:
|
||||
self.ref_join = None
|
||||
self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
|
||||
def forward(self, x, mplex_ref=None, ref=None):
|
||||
if self.ref_join is not None:
|
||||
branch, ref_std = self.ref_join(x, ref)
|
||||
return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
|
||||
else:
|
||||
return self.switch(x, identity=x, att_in=(x, mplex_ref))
|
||||
|
||||
|
||||
class SwitchedSpsr(nn.Module):
|
||||
def __init__(self, in_nc, nf, xforms=8, upscale=4, init_temperature=10):
|
||||
super(SwitchedSpsr, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
# switch options
|
||||
transformation_filters = nf
|
||||
switch_filters = nf
|
||||
switch_reductions = 3
|
||||
switch_processing_layers = 2
|
||||
self.transformation_counts = xforms
|
||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
|
||||
switch_processing_layers, self.transformation_counts, use_exp2=True)
|
||||
pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
|
||||
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
|
||||
transformation_filters, kernel_size=3, depth=3,
|
||||
weight_init_factor=.1)
|
||||
|
||||
# Feature branch
|
||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=True)
|
||||
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=True)
|
||||
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
|
||||
self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
|
||||
# Grad branch
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
|
||||
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
|
||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=True)
|
||||
# Upsampling
|
||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
# Conv used to output grad branch shortcut.
|
||||
self.grad_branch_output_conv = ConvGnLelu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
|
||||
# Conjoin branch.
|
||||
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
|
||||
transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
|
||||
transformation_filters, kernel_size=3, depth=4,
|
||||
weight_init_factor=.1)
|
||||
pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
|
||||
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=True)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)])
|
||||
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)])
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=True)
|
||||
self.final_hr_conv2 = ConvGnLelu(nf, 3, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
|
||||
self.attentions = None
|
||||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = 10000
|
||||
|
||||
def forward(self, x):
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
x = self.model_fea_conv(x)
|
||||
|
||||
x1, a1 = self.sw1(x, do_checkpointing=True)
|
||||
x2, a2 = self.sw2(x1, do_checkpointing=True)
|
||||
x_fea = self.feature_lr_conv(x2)
|
||||
x_fea = self.feature_hr_conv2(x_fea)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True, do_checkpointing=True)
|
||||
x_grad = checkpoint(self.grad_lr_conv, x_grad)
|
||||
x_grad = checkpoint(self.grad_hr_conv, x_grad)
|
||||
x_out_branch = checkpoint(self.upsample_grad, x_grad)
|
||||
x_out_branch = self.grad_branch_output_conv(x_out_branch)
|
||||
|
||||
x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
|
||||
x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True)
|
||||
x_out = checkpoint(self.final_lr_conv, x__branch_pretrain_cat)
|
||||
x_out = checkpoint(self.upsample, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv1, x_out)
|
||||
x_out = self.final_hr_conv2(x_out)
|
||||
|
||||
self.attentions = [a1, a2, a3, a4]
|
||||
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
||||
def set_temperature(self, temp):
|
||||
[sw.set_temperature(temp) for sw in self.switches]
|
||||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
if self.attentions:
|
||||
temp = max(1, 1 + self.init_temperature *
|
||||
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||
self.set_temperature(temp)
|
||||
if step % 200 == 0:
|
||||
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
|
||||
prefix = "attention_map_%i_%%i.png" % (step,)
|
||||
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
||||
|
||||
def get_debug_values(self, step, net):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
|
||||
val = {"switch_temperature": temp}
|
||||
for i in range(len(means)):
|
||||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
@@ -1,163 +0,0 @@
from collections import OrderedDict
import torch
import torch.nn as nn

####################
# Basic blocks
####################
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    # helper selecting activation
    # neg_slope: for leakyrelu and init of prelu
    # n_prelu: for p_relu num_parameters
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer

def norm(norm_type, nc):
    # helper selecting normalization layer
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'group':
        layer = nn.GroupNorm(8, nc)
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer

def pad(pad_type, padding):
    # helper selecting padding layer
    # if padding is 'zero', do by conv layers
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = 'Identity .. \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        return x, self.sub

    def __repr__(self):
        tmpstr = 'Identity + \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr


def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
                  dilation=dilation, bias=bias, groups=groups)
    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)


####################
# Upsampler
####################


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                       pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    '''
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
                      pad_type=pad_type, norm_type=None, act_type=None)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    # Up conv
    # described in https://distill.pub/2016/deconv-checkerboard/
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(upsample, conv)
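As an aside on the removed helpers, a minimal illustrative sketch (not part of the commit) of what the deleted sequential helper does with nested containers; the layer shapes below are arbitrary:

    import torch.nn as nn
    inner = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
    flat = sequential(inner, nn.Conv2d(16, 3, 3, padding=1))
    # flat is a single nn.Sequential holding three modules rather than a nested one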
@@ -1,55 +0,0 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util


class MSRResNet(nn.Module):
    ''' modified SRResNet'''

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        self.recon_trunk = arch_util.make_layer(basic_block, nb)

        # upsampling
        if self.upscale == 2:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)
        elif self.upscale == 3:
            self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(3)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
                                     0.1)
        if self.upscale == 4:
            arch_util.initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        fea = self.lrelu(self.conv_first(x))
        out = self.recon_trunk(fea)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale == 3 or self.upscale == 2:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.HRconv(out)))
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
@@ -1,139 +0,0 @@
import functools

import torch
from torch.nn import init

import models.archs.biggan.biggan_layers as layers
import torch.nn as nn


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'):
    arch = {}
    arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample' : [True] * 6 + [False],
                 'resolution' : [128, 64, 32, 16, 8, 4, 4 ],
                 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                                for i in range(2,8)}}
    arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]],
                 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
                 'downsample' : [True] * 5 + [False],
                 'resolution' : [64, 32, 16, 8, 4, 4],
                 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                                for i in range(2,8)}}
    arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]],
                'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample' : [True] * 4 + [False],
                'resolution' : [32, 16, 8, 4, 4],
                'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                               for i in range(2,7)}}
    arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]],
                'out_channels' : [item * ch for item in [4, 4, 4, 4]],
                'downsample' : [True, True, False, False],
                'resolution' : [16, 16, 16, 16],
                'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                               for i in range(2,6)}}
    return arch


class BigGanDiscriminator(nn.Module):

    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 SN_eps=1e-12, output_dim=1, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN'):
        super(BigGanDiscriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)

        # Initialize weights
        if not skip_init:
            self.init_weights()

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print('Param count for D''s initialized parameters: %d' % self.param_count)

    def forward(self, x, y=None):
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
        return out
@ -1,457 +0,0 @@
|
|||
''' Layers
|
||||
This file contains various layers for the BigGAN models.
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter as P
|
||||
|
||||
|
||||
# Projection of x onto y
|
||||
def proj(x, y):
|
||||
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
|
||||
|
||||
|
||||
# Orthogonalize x wrt list of vectors ys
|
||||
def gram_schmidt(x, ys):
|
||||
for y in ys:
|
||||
x = x - proj(x, y)
|
||||
return x
|
||||
|
||||
|
||||
# Apply num_itrs steps of the power method to estimate top N singular values.
|
||||
def power_iteration(W, u_, update=True, eps=1e-12):
|
||||
# Lists holding singular vectors and values
|
||||
us, vs, svs = [], [], []
|
||||
for i, u in enumerate(u_):
|
||||
# Run one step of the power iteration
|
||||
with torch.no_grad():
|
||||
v = torch.matmul(u, W)
|
||||
# Run Gram-Schmidt to subtract components of all other singular vectors
|
||||
v = F.normalize(gram_schmidt(v, vs), eps=eps)
|
||||
# Add to the list
|
||||
vs += [v]
|
||||
# Update the other singular vector
|
||||
u = torch.matmul(v, W.t())
|
||||
# Run Gram-Schmidt to subtract components of all other singular vectors
|
||||
u = F.normalize(gram_schmidt(u, us), eps=eps)
|
||||
# Add to the list
|
||||
us += [u]
|
||||
if update:
|
||||
u_[i][:] = u
|
||||
# Compute this singular value and add it to the list
|
||||
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
|
||||
# svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
|
||||
return svs, us, vs
|
||||
|
||||
|
||||
# Convenience passthrough function
|
||||
class identity(nn.Module):
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
# Spectral normalization base class
|
||||
class SN(object):
|
||||
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
|
||||
# Number of power iterations per step
|
||||
self.num_itrs = num_itrs
|
||||
# Number of singular values
|
||||
self.num_svs = num_svs
|
||||
# Transposed?
|
||||
self.transpose = transpose
|
||||
# Epsilon value for avoiding divide-by-0
|
||||
self.eps = eps
|
||||
# Register a singular vector for each sv
|
||||
for i in range(self.num_svs):
|
||||
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
|
||||
self.register_buffer('sv%d' % i, torch.ones(1))
|
||||
|
||||
# Singular vectors (u side)
|
||||
@property
|
||||
def u(self):
|
||||
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
|
||||
|
||||
# Singular values;
|
||||
# note that these buffers are just for logging and are not used in training.
|
||||
@property
|
||||
def sv(self):
|
||||
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
|
||||
|
||||
# Compute the spectrally-normalized weight
|
||||
def W_(self):
|
||||
W_mat = self.weight.view(self.weight.size(0), -1)
|
||||
if self.transpose:
|
||||
W_mat = W_mat.t()
|
||||
# Apply num_itrs power iterations
|
||||
for _ in range(self.num_itrs):
|
||||
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
|
||||
# Update the svs
|
||||
if self.training:
|
||||
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
|
||||
for i, sv in enumerate(svs):
|
||||
self.sv[i][:] = sv
|
||||
return self.weight / svs[0]
|
||||
|
||||
|
||||
# 2D Conv layer with spectral norm
|
||||
class SNConv2d(nn.Conv2d, SN):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True,
|
||||
num_svs=1, num_itrs=1, eps=1e-12):
|
||||
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, bias)
|
||||
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv2d(x, self.W_(), self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
# Linear layer with spectral norm
|
||||
class SNLinear(nn.Linear, SN):
|
||||
def __init__(self, in_features, out_features, bias=True,
|
||||
num_svs=1, num_itrs=1, eps=1e-12):
|
||||
nn.Linear.__init__(self, in_features, out_features, bias)
|
||||
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
return F.linear(x, self.W_(), self.bias)
|
||||
|
||||
|
||||
# Embedding layer with spectral norm
|
||||
# We use num_embeddings as the dim instead of embedding_dim here
|
||||
# for convenience sake
|
||||
class SNEmbedding(nn.Embedding, SN):
|
||||
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
|
||||
max_norm=None, norm_type=2, scale_grad_by_freq=False,
|
||||
sparse=False, _weight=None,
|
||||
num_svs=1, num_itrs=1, eps=1e-12):
|
||||
nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
|
||||
max_norm, norm_type, scale_grad_by_freq,
|
||||
sparse, _weight)
|
||||
SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
return F.embedding(x, self.W_())
|
||||
|
||||
|
||||
# A non-local block as used in SA-GAN
|
||||
# Note that the implementation as described in the paper is largely incorrect;
|
||||
# refer to the released code for the actual implementation.
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, ch, which_conv=SNConv2d, name='attention'):
|
||||
super(Attention, self).__init__()
|
||||
# Channel multiplier
|
||||
self.ch = ch
|
||||
self.which_conv = which_conv
|
||||
self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
|
||||
self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
|
||||
self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
|
||||
self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
|
||||
# Learnable gain parameter
|
||||
self.gamma = P(torch.tensor(0.), requires_grad=True)
|
||||
|
||||
def forward(self, x, y=None):
|
||||
# Apply convs
|
||||
theta = self.theta(x)
|
||||
phi = F.max_pool2d(self.phi(x), [2, 2])
|
||||
g = F.max_pool2d(self.g(x), [2, 2])
|
||||
# Perform reshapes
|
||||
theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
|
||||
phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
|
||||
g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
|
||||
# Matmul and softmax to get attention maps
|
||||
beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
|
||||
# Attention map times g path
|
||||
o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
|
||||
return self.gamma * o + x
|
||||
|
||||
|
||||
# Fused batchnorm op
|
||||
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
|
||||
# Apply scale and shift--if gain and bias are provided, fuse them here
|
||||
# Prepare scale
|
||||
scale = torch.rsqrt(var + eps)
|
||||
# If a gain is provided, use it
|
||||
if gain is not None:
|
||||
scale = scale * gain
|
||||
# Prepare shift
|
||||
shift = mean * scale
|
||||
# If bias is provided, use it
|
||||
if bias is not None:
|
||||
shift = shift - bias
|
||||
return x * scale - shift
|
||||
# return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
|
||||
|
||||
|
||||
# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = (m2 - m ** 2)
    # Cast back to float 16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)


# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Register buffers
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # If not in training mode, use the stored statistics
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)

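# Hedged usage sketch (gain/bias shapes are assumptions, not repo defaults): with
# accumulate_standing set, training batches add raw mean/var sums that the eval-time
# forward() divides by the accumulation counter ("standing statistics").
if __name__ == '__main__':
    _bn = myBN(num_channels=8)
    _bn.accumulate_standing = True
    _gain, _bias = torch.ones(1, 8, 1, 1), torch.zeros(1, 8, 1, 1)
    for _ in range(10):
        _ = _bn(torch.randn(4, 8, 16, 16), _gain, _bias)   # training mode by default
    _bn.eval()
    _out = _bn(torch.randn(4, 8, 16, 16), _gain, _bias)    # uses accumulated stats
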
# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)

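# Hedged examples of the norm_style strings parsed above (illustrative only):
# 'ch_16' groups every 16 channels together, 'grp_8' asks for exactly 8 groups,
# anything else falls back to 16 groups.
if __name__ == '__main__':
    _feats = torch.randn(2, 64, 8, 8)
    assert groupnorm(_feats, 'ch_16').shape == _feats.shape   # 64 // 16 = 4 groups
    assert groupnorm(_feats, 'grp_8').shape == _feats.shape   # 8 groups
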
# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn', ):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica or self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        # else:
        else:
            if self.norm_style == 'bn':
                out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                                   self.training, 0.1, self.eps)
            elif self.norm_style == 'in':
                out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                      self.training, 0.1, self.eps)
            elif self.norm_style == 'gn':
                out = groupnorm(x, self.norm_style)
            elif self.norm_style == 'nonorm':
                out = x
            return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)

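# Hedged usage sketch (parameters are illustrative, not repo defaults): y is a
# per-sample class embedding; the two linear layers turn it into per-channel
# gain and bias that modulate the normalized features.
if __name__ == '__main__':
    _cbn = ccbn(output_size=64, input_size=128, which_linear=nn.Linear, mybn=True)
    _feats = torch.randn(2, 64, 32, 32)
    _y = torch.randn(2, 128)
    _out = _cbn(_feats, _y)   # -> (2, 64, 32, 32)
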
# Normal, non-class-conditional BN
class bn(nn.Module):
    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica or mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        else:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)

# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x

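# Hedged wiring sketch (all hyperparameters below are assumptions, not repo
# defaults): conv/bn/upsample factories are passed in as partials, BigGAN-style.
if __name__ == '__main__':
    import functools
    _which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
    _which_bn = functools.partial(ccbn, input_size=128, which_linear=nn.Linear, mybn=True)
    _blk = GBlock(64, 32, which_conv=_which_conv, which_bn=_which_bn,
                  activation=nn.ReLU(inplace=False),
                  upsample=functools.partial(F.interpolate, scale_factor=2))
    _h = _blk(torch.randn(2, 64, 16, 16), torch.randn(2, 128))   # -> (2, 32, 32, 32)
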
# Residual block for the discriminator
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None, ):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = True if (in_channels != out_channels) or downsample else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # h = self.activation(x) # NOT TODAY SATAN
            # Andy's note: This line *must* be an out-of-place ReLU or it
            # will negatively affect the shortcut connection.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)
@@ -1,14 +1,11 @@
import copy
import random
from functools import wraps
from time import time

import torch
import torch.nn.functional as F
from torch import nn

from data.byol_attachment import reconstructed_shared_regions
-from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \
+from models.archs.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
    update_moving_average
from utils.util import checkpoint
1
codes/models/archs/flownet2
Submodule
@@ -0,0 +1 @@
Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3
@@ -1,47 +0,0 @@
import torch
from torch import nn
from lambda_networks import LambdaLayer
from torch.nn import GroupNorm

from models.archs.RRDBNet_arch import ResidualDenseBlock
from models.archs.arch_util import ConvGnLelu


class LambdaRRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        mid_channels (int): Channel number of intermediate features.
        growth_channels (int): Channels for each growth.
    """

    def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
        super(LambdaRRDB, self).__init__()
        if reduce_to is None:
            reduce_to = mid_channels
        self.lam1 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn1 = GroupNorm(num_groups=8, num_channels=mid_channels)
        self.lam2 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn2 = GroupNorm(num_groups=8, num_channels=mid_channels)
        self.lam3 = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn3 = GroupNorm(num_groups=8, num_channels=mid_channels)
        self.conv = ConvGnLelu(reduce_to, reduce_to, kernel_size=1, bias=True, norm=False, activation=False, weight_init_factor=.1)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = self.lam1(x)
        out = self.gn1(out)
        out = self.lam2(out)
        out = self.gn2(out)
        out = self.lam3(out)
        out = self.gn3(out)
        return self.conv(out) * .2 + x
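# Hedged usage sketch (not part of this file; requires the external
# `lambda_networks` package, channel count and input size are illustrative):
if __name__ == '__main__':
    _block = LambdaRRDB(mid_channels=64)
    _out = _block(torch.randn(1, 64, 32, 32))   # same shape as input, residual scaled by 0.2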
@@ -1,206 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F

from models.archs.RRDBNet_arch import RRDB
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu, PixelUnshuffle
from utils.util import checkpoint, sequential_checkpoint


class MultiLevelRRDB(nn.Module):
    def __init__(self, nf, gc, levels):
        super().__init__()
        self.levels = levels
        self.level_rrdbs = nn.ModuleList([RRDB(nf, growth_channels=gc) for i in range(levels)])

    # Trunks should be fed in in order HR->LR
    def forward(self, trunk):
        for i in reversed(range(self.levels)):
            lvl_scale = (2**i)
            lvl_res = self.level_rrdbs[i](F.interpolate(trunk, scale_factor=1/lvl_scale, mode="area"), return_residual=True)
            trunk = trunk + F.interpolate(lvl_res, scale_factor=lvl_scale, mode="nearest")
        return trunk


class MultiResRRDBNet(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 l1_blocks=3,
                 l2_blocks=4,
                 l3_blocks=6,
                 growth_channels=32,
                 scale=4,
                 ):
        super().__init__()
        self.scale = scale
        self.in_channels = in_channels

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=1, padding=3)

        self.l3_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 3) for _ in range(l1_blocks)])
        self.l2_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)])
        self.l1_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 1) for _ in range(l3_blocks)])
        self.block_levels = [self.l3_blocks, self.l2_blocks, self.l1_blocks]

        self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        for m in [
            self.conv_first, self.conv_first, self.conv_body, self.conv_up1,
            self.conv_up2, self.conv_hr, self.conv_last
        ]:
            if m is not None:
                default_init_weights(m, 0.1)

    def forward(self, x):
        trunk = self.conv_first(x)
        for block_set in self.block_levels:
            for block in block_set:
                trunk = checkpoint(block, trunk)

        body_feat = self.conv_body(trunk)
        feat = trunk + body_feat

        # upsample
        out = self.lrelu(
            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        if self.scale == 4:
            out = self.lrelu(
                self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
        else:
            out = self.lrelu(self.conv_up2(out))
        out = self.conv_last(self.lrelu(self.conv_hr(out)))

        return out

    def visual_dbg(self, step, path):
        pass


class SteppedResRRDBNet(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 l1_blocks=3,
                 l2_blocks=3,
                 growth_channels=32,
                 scale=4,
                 ):
        super().__init__()
        self.scale = scale
        self.in_channels = in_channels

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=2, padding=3)
        self.conv_second = nn.Conv2d(mid_channels, mid_channels*2, 3, stride=2, padding=1)

        self.l1_blocks = nn.Sequential(*[RRDB(mid_channels*2, growth_channels*2) for _ in range(l1_blocks)])
        self.l1_upsample_conv = nn.Conv2d(mid_channels*2, mid_channels, 3, stride=1, padding=1)
        self.l2_blocks = nn.Sequential(*[RRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)])

        self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        for m in [
            self.conv_first, self.conv_second, self.l1_upsample_conv, self.conv_body, self.conv_up1,
            self.conv_up2, self.conv_hr, self.conv_last
        ]:
            if m is not None:
                default_init_weights(m, 0.1)

    def forward(self, x):
        trunk = self.conv_first(x)
        trunk = self.conv_second(trunk)
        trunk = sequential_checkpoint(self.l1_blocks, len(self.l1_blocks), trunk)
        trunk = F.interpolate(trunk, scale_factor=2, mode="nearest")
        trunk = self.l1_upsample_conv(trunk)
        trunk = sequential_checkpoint(self.l2_blocks, len(self.l2_blocks), trunk)
        body_feat = self.conv_body(trunk)
        feat = trunk + body_feat

        # upsample
        out = self.lrelu(
            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        if self.scale == 4:
            out = self.lrelu(
                self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
        else:
            out = self.lrelu(self.conv_up2(out))
        out = self.conv_last(self.lrelu(self.conv_hr(out)))

        return out

    def visual_dbg(self, step, path):
        pass


class PixelShufflingSteppedResRRDBNet(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 l1_blocks=3,
                 l2_blocks=3,
                 growth_channels=32,
                 scale=2,
                 ):
        super().__init__()
        self.scale = scale * 2  # This RRDB operates at half-scale resolution.
        self.in_channels = in_channels

        self.pix_unshuffle = PixelUnshuffle(4)
        self.conv_first = nn.Conv2d(4*4*in_channels, mid_channels*2, 3, stride=1, padding=1)

        self.l1_blocks = nn.Sequential(*[RRDB(mid_channels*2, growth_channels*2) for _ in range(l1_blocks)])
        self.l1_upsample_conv = nn.Conv2d(mid_channels*2, mid_channels, 3, stride=1, padding=1)
        self.l2_blocks = nn.Sequential(*[RRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)])

        self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        for m in [
            self.conv_first, self.l1_upsample_conv, self.conv_body, self.conv_up1,
            self.conv_up2, self.conv_hr, self.conv_last
        ]:
            if m is not None:
                default_init_weights(m, 0.1)

    def forward(self, x):
        trunk = self.conv_first(self.pix_unshuffle(x))
        trunk = sequential_checkpoint(self.l1_blocks, len(self.l1_blocks), trunk)
        trunk = F.interpolate(trunk, scale_factor=2, mode="nearest")
        trunk = self.l1_upsample_conv(trunk)
        trunk = sequential_checkpoint(self.l2_blocks, len(self.l2_blocks), trunk)
        body_feat = self.conv_body(trunk)
        feat = trunk + body_feat

        # upsample
        out = self.lrelu(
            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        if self.scale == 4:
            out = self.lrelu(
                self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
        else:
            out = self.lrelu(self.conv_up2(out))
        out = self.conv_last(self.lrelu(self.conv_hr(out)))

        return out

    def visual_dbg(self, step, path):
        pass
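# Hedged usage sketch (not repo code; assumes the repo's RRDB supports the
# return_residual=True call used above). A 4x multi-resolution generator:
#
#     net = MultiResRRDBNet(in_channels=3, out_channels=3, mid_channels=64,
#                           l1_blocks=1, l2_blocks=1, l3_blocks=1, scale=4)
#     sr = net(torch.rand(1, 3, 32, 32))   # expected output shape: (1, 3, 128, 128)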
@@ -1,98 +0,0 @@
import torch
from torch import nn

from models.archs.arch_util import ConvGnLelu, ExpansionBlock
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from utils.util import checkpoint
import torch.nn.functional as F


class Pyramid(nn.Module):
    def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
                 norm=True, return_outlevels=False):
        super(Pyramid, self).__init__()
        levels = []
        current_filters = nf
        self.return_outlevels = return_outlevels
        for d in range(depth):
            level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
            current_filters = int(current_filters*scale_per_level)
            for pc in range(processing_convs_per_layer):
                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
            levels.append(nn.Sequential(*level))
        self.downsamples = nn.ModuleList(levels)
        if processing_at_point > 0:
            point_processor = []
            for p in range(processing_at_point):
                point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
            self.point_processor = nn.Sequential(*point_processor)
        else:
            self.point_processor = None
        levels = []
        for d in range(depth):
            level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
            current_filters = int(current_filters / scale_per_level)
            for pc in range(processing_convs_per_layer):
                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
            levels.append(nn.ModuleList(level))
        self.upsamples = nn.ModuleList(levels)

    def forward(self, x):
        passthroughs = []
        fea = x
        for lvl in self.downsamples:
            passthroughs.append(fea)
            fea = lvl(fea)
        out_levels = []
        fea = self.point_processor(fea)
        for i, lvl in enumerate(self.upsamples):
            out_levels.append(fea)
            for j, sublvl in enumerate(lvl):
                if j == 0:
                    fea = sublvl(fea, passthroughs[-1-i])
                else:
                    fea = sublvl(fea)

        out_levels.append(fea)

        if self.return_outlevels:
            return tuple(out_levels)
        else:
            return fea


class BasicResamplingFlowNet(nn.Module):
    def create_termini(self, filters):
        return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
                             nn.Tanh())

    def __init__(self, nf, resample_scale=1):
        super(BasicResamplingFlowNet, self).__init__()
        self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
        self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
        self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
                                      self.create_termini(nf*1.5**2),
                                      self.create_termini(nf*1.5)])
        self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
                                      ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
                                      ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
                                      ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
                                      nn.Tanh())
        self.scale = resample_scale
        self.resampler = Resample2d()

    def forward(self, left, right):
        fea = self.initial_conv(torch.cat([left, right], dim=1))
        levels = checkpoint(self.pyramid, fea)
        flos = []
        compares = []
        for i, level in enumerate(levels):
            if i == 3:
                flow = checkpoint(self.terminus, level) * self.scale
            else:
                flow = self.termini[i](level) * self.scale
            img_scale = 1/2**(3-i)
            flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
            compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
        flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
        return flos, compares, flos_structural_var
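# Hedged sketch (not repo code; Resample2d needs the flownet2 CUDA extension, so
# this is illustrative only): the pyramid alone acts as a small encoder/decoder
# that can return coarse-to-fine feature maps.
#
#     pyr = Pyramid(nf=16, depth=3, processing_convs_per_layer=1, processing_at_point=1,
#                   return_outlevels=True)
#     levels = pyr(torch.randn(1, 16, 64, 64))   # tuple of depth + 1 feature maps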
@@ -1,80 +0,0 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True, raw=False):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    elif raw:
        return ssim_map
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, raw=False):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.raw = raw
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average, self.raw)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
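# Hedged usage sketch (illustrative, not repo code): the functional form returns a
# scalar score, while SSIM(size_average=False, raw=True) yields the dense SSIM map.
if __name__ == '__main__':
    _a = torch.rand(1, 3, 64, 64)
    _b = torch.rand(1, 3, 64, 64)
    print(ssim(_a, _a))                                    # ~1.0 for identical images
    print(ssim(_a, _b))                                    # noticeably lower for noise
    _dense = SSIM(size_average=False, raw=True)(_a, _b)    # per-pixel map, (1, 3, 64, 64)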
@@ -1,221 +0,0 @@
import torch.nn as nn
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.util import checkpoint

from torch.autograd import Variable

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size//2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)

def make_model(args, parent=False):
    return RCAN(args)


## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        # res = self.body(x).mul(self.res_scale)
        res += x
        return res


## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = []
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    def __init__(self, args, conv=default_conv):
        super(RCAN, self).__init__()

        n_resgroups = args.n_resgroups
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        reduction = args.reduction
        scale = args.scale
        act = nn.ReLU(True)

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        modules_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
            for _ in range(n_resgroups)]

        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)]

        self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
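# Hedged construction sketch (not repo code): RCAN expects an argparse-style
# namespace; the attribute names below mirror the reads in the constructor.
if __name__ == '__main__':
    from types import SimpleNamespace
    _args = SimpleNamespace(n_resgroups=2, n_resblocks=2, n_feats=32, reduction=16,
                            scale=2, rgb_range=255, n_colors=3, res_scale=1)
    _model = RCAN(_args)
    _sr = _model(torch.rand(1, 3, 48, 48) * 255)   # -> (1, 3, 96, 96)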
0
codes/models/archs/tecogan/__init__.py
Normal file
@@ -2,8 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
-from models.steps.injectors import Injector
+from models.injectors import Injector
from utils.util import checkpoint

0
codes/models/custom_training_components/__init__.py
Normal file
@@ -6,9 +6,8 @@ import torchvision
from torch.cuda.amp import autocast

from data.multiscale_dataset import build_multiscale_patch_index_map
-from models.steps.injectors import Injector
-from models.steps.losses import extract_params_from_state
-from models.steps.tecogan_losses import extract_inputs_index
+from models.injectors import Injector
+from models.losses import extract_params_from_state
import os.path as osp

@@ -1,8 +1,8 @@
import torch
from torch.cuda.amp import autocast
-from models.flownet2.networks.resample2d_package.resample2d import Resample2d
-from models.flownet2.utils.flow_utils import flow2img
-from models.steps.injectors import Injector
+from models.archs.flownet2.networks import Resample2d
+from models.archs.flownet2 import flow2img
+from models.injectors import Injector


def create_stereoscopic_injector(opt, env):
@@ -1,15 +1,15 @@
from torch.cuda.amp import autocast

from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
-from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
-from models.flownet2.networks.resample2d_package.resample2d import Resample2d
-from models.steps.injectors import Injector
+from models.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
+from models.archs.flownet2.networks import Resample2d
+from models.injectors import Injector
import torch
import torch.nn.functional as F
import os
import os.path as osp
import torchvision
import torch.distributed as dist


def create_teco_loss(opt, env):
    type = opt['type']
@@ -3,22 +3,20 @@ import random
import torch.nn
from torch.cuda.amp import autocast

from models.archs.SPSR_arch import ImageGradientNoPadding
from models.archs.pytorch_ssim import SSIM
from utils.weight_scheduler import get_scheduler_for_opt
-from models.steps.losses import extract_params_from_state
+from models.losses import extract_params_from_state


# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
    type = opt_inject['type']
    if 'teco_' in type:
-        from models.steps.tecogan_losses import create_teco_injector
+        from models.custom_training_components import create_teco_injector
        return create_teco_injector(opt_inject, env)
    elif 'progressive_' in type:
-        from models.steps.progressive_zoom import create_progressive_zoom_injector
+        from models.custom_training_components import create_progressive_zoom_injector
        return create_progressive_zoom_injector(opt_inject, env)
    elif 'stereoscopic_' in type:
-        from models.steps.stereoscopic import create_stereoscopic_injector
+        from models.custom_training_components import create_stereoscopic_injector
        return create_stereoscopic_injector(opt_inject, env)
    elif 'igpt' in type:
        from models.archs.transformers.igpt import gpt2
@@ -29,8 +27,6 @@ def create_injector(opt_inject, env):
        return DiscriminatorInjector(opt_inject, env)
    elif type == 'scheduled_scalar':
        return ScheduledScalarInjector(opt_inject, env)
    elif type == 'img_grad':
        return ImageGradientInjector(opt_inject, env)
    elif type == 'add_noise':
        return AddNoiseInjector(opt_inject, env)
    elif type == 'greyscale':
@@ -47,14 +43,10 @@ def create_injector(opt_inject, env):
        return ForEachInjector(opt_inject, env)
    elif type == 'constant':
        return ConstantInjector(opt_inject, env)
    elif type == 'fft':
        return ImageFftInjector(opt_inject, env)
    elif type == 'extract_indices':
        return IndicesExtractor(opt_inject, env)
    elif type == 'random_shift':
        return RandomShiftInjector(opt_inject, env)
    elif type == 'psnr':
        return PsnrInjector(opt_inject, env)
    elif type == 'batch_rotate':
        return BatchRotateInjector(opt_inject, env)
    elif type == 'sr_diffs':
@@ -134,16 +126,6 @@ class DiscriminatorInjector(Injector):
        return new_state


# Creates an image gradient from [in] and injects it into [out]
class ImageGradientInjector(Injector):
    def __init__(self, opt, env):
        super(ImageGradientInjector, self).__init__(opt, env)
        self.img_grad_fn = ImageGradientNoPadding().to(env['device'])

    def forward(self, state):
        return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}


# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
# of something over time.
class ScheduledScalarInjector(Injector):
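# Hedged configuration sketch (not part of this commit): injectors of this kind
# are normally built from an options dict; assuming the base Injector only reads
# 'in'/'out' plus the environment's 'device', wiring the gradient injector up
# would look roughly like this.
#
#     opt_inject = {'type': 'img_grad', 'in': 'gen', 'out': 'gen_grad'}
#     injector = create_injector(opt_inject, env={'device': 'cpu'})
#     state = injector({'gen': torch.rand(1, 3, 64, 64)})   # adds 'gen_grad' to the state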
@@ -320,37 +302,6 @@ class ConstantInjector(Injector):
        return { self.opt['out']: out }


class ImageFftInjector(Injector):
    def __init__(self, opt, env):
        super(ImageFftInjector, self).__init__(opt, env)
        self.is_forward = opt['forward']  # Whether to compute a forward FFT or backward.
        self.eps = 1e-100

    def forward(self, state):
        if self.forward:
            fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True)
            b, f, h, w, c = fftim.shape
            fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w)
            # Normalize across spatial dimension
            mean = torch.mean(fftim, dim=(0,1))
            fftim = fftim - mean
            std = torch.std(fftim, dim=(0,1))
            fftim = (fftim + self.eps) / std
            return {self.output: fftim,
                    '%s_std' % (self.output,): std,
                    '%s_mean' % (self.output,): mean}
        else:
            b, f, h, w = state[self.input].shape
            # First, de-normalize the FFT.
            mean = state['%s_mean' % (self.input,)]
            std = state['%s_std' % (self.input,)]
            fftim = state[self.input] * std + mean - self.eps
            # Second, recover the FFT dimensions from the given filters.
            fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2)
            im = torch.irfft(fftim, signal_ndim=2, normalized=True)
            return {self.output: im}


class IndicesExtractor(Injector):
    def __init__(self, opt, env):
        super(IndicesExtractor, self).__init__(opt, env)
@@ -374,21 +325,6 @@ class RandomShiftInjector(Injector):
        return {self.output: img}


class PsnrInjector(Injector):
    def __init__(self, opt, env):
        super(PsnrInjector, self).__init__(opt, env)
        self.ssim = SSIM(size_average=False, raw=True)
        self.scale = opt['output_scale_divisor']
        self.exp = opt['exponent'] if 'exponent' in opt.keys() else 1

    def forward(self, state):
        img1, img2 = state[self.input[0]], state[self.input[1]]
        ssim = self.ssim(img1, img2)
        areal_se = torch.nn.functional.interpolate(ssim, scale_factor=1/self.scale,
                                                   mode="area")
        return {self.output: areal_se}


class BatchRotateInjector(Injector):
    def __init__(self, opt, env):
        super(BatchRotateInjector, self).__init__(opt, env)
@@ -436,7 +372,7 @@ class MultiFrameCombiner(Injector):
        self.in_hq_key = opt['in_hq']
        self.out_lq_key = opt['out']
        self.out_hq_key = opt['out_hq']
-        from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+        from models.archs.flownet2.networks import Resample2d
        self.resampler = Resample2d()

    def combine(self, state):
@@ -6,13 +6,12 @@ from models.loss import GANLoss
import random
import functools
import torch.nn.functional as F
import numpy as np


def create_loss(opt_loss, env):
    type = opt_loss['type']
    if 'teco_' in type:
-        from models.steps.tecogan_losses import create_teco_loss
+        from models.custom_training_components.tecogan_losses import create_teco_loss
        return create_teco_loss(opt_loss, env)
    elif 'stylegan2_' in type:
        from models.archs.stylegan import create_stylegan2_loss
@@ -10,16 +10,12 @@ import models.archs.stylegan.stylegan2_lucidrains as stylegan2

import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.SPSR_arch as spsr
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.feature_arch as feature_arch
import models.archs.rcan as rcan
from models.archs import srg2_classic
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
-from models.archs.teco_resgen import TecoGen
+from models.archs.tecogan.teco_resgen import TecoGen
from utils.util import opt_get

logger = logging.getLogger('base')
@@ -30,11 +26,7 @@ def define_G(opt, opt_net, scale=None):
        scale = opt['scale']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
-    elif 'RRDBNet' in which_model:
+    if 'RRDBNet' in which_model:
        if which_model == 'RRDBNetBypass':
            block = RRDBNet_arch.RRDBWithBypass
        elif which_model == 'RRDBNetLambda':
@@ -50,24 +42,6 @@ def define_G(opt, opt_net, scale=None):
                                      mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
                                      output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
                                      initial_stride=initial_stride)
    elif which_model == "multires_rrdb":
        from models.archs.multi_res_rrdb import MultiResRRDBNet
        netG = MultiResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                               mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'],
                               l2_blocks=opt_net['l2'], l3_blocks=opt_net['l3'],
                               growth_channels=opt_net['gc'], scale=opt_net['scale'])
    elif which_model == "twostep_rrdb":
        from models.archs.multi_res_rrdb import PixelShufflingSteppedResRRDBNet
        netG = PixelShufflingSteppedResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                               mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'],
                                               l2_blocks=opt_net['l2'],
                                               growth_channels=opt_net['gc'], scale=opt_net['scale'])
    elif which_model == 'rcan':
        #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
        opt_net['rgb_range'] = 255
        opt_net['n_colors'] = 3
        args_obj = munchify(opt_net)
        netG = rcan.RCAN(args_obj)
    elif which_model == "ConfigurableSwitchedResidualGenerator2":
        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                       switch_reductions=opt_net['switch_reductions'],
@@ -87,22 +61,8 @@ def define_G(opt, opt_net, scale=None):
                                                                       initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                                                       heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                                       upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
    elif which_model == 'spsr':
        netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'spsr_net_improved':
        netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                                      nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == "spsr_switched":
        netG = spsr.SwitchedSpsr(in_nc=3, nf=opt_net['nf'], upscale=opt_net['scale'], init_temperature=opt_net['temperature'])
    elif which_model == "spsr7":
        recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
        netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                          multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
    elif which_model == "flownet2":
-        from models.flownet2.models import FlowNet2
+        from models.archs.flownet2 import FlowNet2
        ld = 'load_path' in opt_net.keys()
        args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
        netG = FlowNet2(args)
@@ -148,12 +108,12 @@ def define_G(opt, opt_net, scale=None):
        from models.archs.transformers.igpt.gpt2 import iGPT2
        netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
    elif which_model == 'byol':
-        from models.byol.byol_model_wrapper import BYOL
+        from models.archs.byol.byol_model_wrapper import BYOL
        subnet = define_G(opt, opt_net['subnet'])
        netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                    structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
    elif which_model == 'structural_byol':
-        from models.byol.byol_structural import StructuralBYOL
+        from models.archs.byol.byol_structural import StructuralBYOL
        subnet = define_G(opt, opt_net['subnet'])
        netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                              pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
@@ -206,8 +166,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
        #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
        #netD.load_state_dict(state_dict, strict=False)
        netD.fc = torch.nn.Linear(512 * 4, 1)
    elif which_model == 'biggan_resnet':
        netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2))
    elif which_model == 'discriminator_pix':
        netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet":
@@ -1,12 +1,12 @@
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler

from utils.loss_accumulator import LossAccumulator
from torch.nn import Module
import logging
-from models.steps.losses import create_loss
+from models.losses import create_loss
import torch
from collections import OrderedDict
-from .injectors import create_injector
+from models.injectors import create_injector
from utils.util import recursively_detach

logger = logging.getLogger('base')
@@ -8,14 +8,11 @@ import os
import utils
from models.ExtensibleTrainer import ExtensibleTrainer
from models.networks import define_F
from models.steps.losses import FeatureLoss
from utils import options as option
import utils.util as util
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import torchvision


if __name__ == "__main__":
    bin_path = "f:\\binned"
@@ -1,22 +0,0 @@
import onnx
import numpy as np
import time

init_temperature = 10
final_temperature_step = 50
heightened_final_step = 90
heightened_temp_min = .1

for step in range(100):
    temp = max(1, 1 + init_temperature * (final_temperature_step - step) / final_temperature_step)
    if temp == 1 and step > final_temperature_step and heightened_final_step and heightened_final_step != 1:
        # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
        # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
        h_steps_total = heightened_final_step - final_temperature_step
        h_steps_current = min(step - final_temperature_step, h_steps_total)
        # The "gap" will represent the steps that need to be traveled as a linear function.
        h_gap = 1 / heightened_temp_min
        temp = h_gap * h_steps_current / h_steps_total
        # Invert temperature to represent reality on this side of the curve
        temp = 1 / temp
    print("%i: %f" % (step, temp))
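# Worked example under the constants above (not part of the original script):
# at step 70, h_steps_total = 90 - 50 = 40, h_steps_current = min(70 - 50, 40) = 20,
# h_gap = 1 / 0.1 = 10, so temp = 10 * 20 / 40 = 5, which inverts to 1 / 5 = 0.2.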