Large cleanup

Removed a lot of old code that I won't be touching again. Refactored some
code elements into more logical places.
James Betker 2020-12-18 09:10:44 -07:00
parent 2f0a52b7db
commit b905b108da
29 changed files with 31 additions and 2298 deletions

.gitmodules

@@ -8,3 +8,6 @@
 path = codes/models/flownet2
 url = https://github.com/neonbjb/flownet2-pytorch.git
 branch = master
+[submodule "codes/models/archs/flownet2"]
+path = codes/models/archs/flownet2
+url = https://github.com/neonbjb/flownet2-pytorch.git


@@ -8,8 +8,8 @@ import torch.nn as nn
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
 from models.base_model import BaseModel
-from models.steps.injectors import create_injector
-from models.steps.steps import ConfigurableStep
+from models.injectors import create_injector
+from models.steps import ConfigurableStep
 from models.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils


@@ -1,668 +0,0 @@
import functools
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from utils.util import checkpoint
from models.archs import SPSR_util as B
from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \
QueryKeyMultiplexer, QueryKeyPyramidMultiplexer, ConvBasisMultiplexer
from models.archs.arch_util import ConvGnLelu, UpconvBlock, MultiConvBlock, ReferenceJoinBlock
from switched_conv.switched_conv import compute_attention_specificity
from switched_conv.switched_conv_util import save_attention_to_image_rgb
from .RRDBNet_arch import RRDB
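# Per-channel image gradient magnitude via fixed vertical/horizontal difference
# kernels: sqrt(gx^2 + gy^2 + eps). Note that with these 3x3 kernels, padding=2
# makes each output map 2px larger than its input in each spatial dimension;
# ImageGradientNoPadding below uses padding=1 to preserve the input size.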
class ImageGradient(nn.Module):
def __init__(self):
super(ImageGradient, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # Fixed (non-learned) filters; registered as buffers so they follow the
        # module's device and dtype instead of being pinned to CUDA.
        self.register_buffer('weight_h', kernel_h)
        self.register_buffer('weight_v', kernel_v)
def forward(self, x):
x0 = x[:, 0]
x1 = x[:, 1]
x2 = x[:, 2]
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2)
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2)
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2)
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2)
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
x = torch.cat([x0, x1, x2], dim=1)
return x
class ImageGradientNoPadding(nn.Module):
def __init__(self):
super(ImageGradientNoPadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False)
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False)
def forward(self, x):
x = x.float()
x_list = []
for i in range(x.shape[1]):
x_i = x[:, i]
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
x_list.append(x_i)
x = torch.cat(x_list, dim = 1)
return x
####################
# Generator
####################
class SPSRNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(SPSRNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
self.scale=upscale
if upscale == 3:
n_upscale = 1
fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=7, norm=False, activation=False)
self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
rb_blocks = [RRDB(nf) for _ in range(nb)]
LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
upsample_block = UpconvBlock
        if upscale == 3:
            # Kept as a list so the *upsampler unpacking below works for the single-block case too.
            upsampler = [upsample_block(nf, nf, activation=True)]
        else:
            upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)]
self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
*upsampler, self.HR_conv0_new)
self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=3, norm=False, activation=False)
self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_1 = RRDB(nf * 2)
self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_2 = RRDB(nf * 2)
self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_3 = RRDB(nf * 2)
self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_4 = RRDB(nf * 2)
self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
        if upscale == 3:
            # As above, kept as a list so *b_upsampler unpacks correctly.
            b_upsampler = [UpconvBlock(nf, nf, activation=True)]
        else:
            b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)]
b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False)
self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False)
self.f_block = RRDB(nf * 2)
self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False)
self.get_g_nopadding = ImageGradientNoPadding()
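    # bl1..bl6 split the RRDB trunk into chunks so forward() can wrap each chunk in
    # torch.utils.checkpoint: bl1..bl4 each apply five RRDB blocks from the
    # ShortcutBlock's submodule, bl5 applies the remaining blocks (including the
    # trailing LR conv), and bl6 adds the trunk residual and upsamples.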
def bl1(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i](x)
return x
def bl2(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i+5](x)
return x
def bl3(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i+10](x)
return x
def bl4(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i+15](x)
return x
def bl5(self, x):
block_list = self.model[1].sub
x = block_list[20:](x)
return x
def bl6(self, x_ori, x):
x = x_ori + x
x = self.model[2:](x)
x = self.HR_conv1_new(x)
return x
def branch_bl1(self, x_grad, ref_grad):
x_b_fea = self.b_fea_conv(x_grad)
x_b_ref = self.b_ref_conv(ref_grad)
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
return x_b_fea
def branch_bl2(self, x_b_fea, x_fea1):
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
return x_cat_1
def branch_bl3(self, x_cat_1, x_fea2):
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
return x_cat_2
def branch_bl4(self, x_cat_2, x_fea3):
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
return x_cat_3
def branch_bl5(self, x_cat_3, x_fea4, x_b_fea):
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
return x_branch
def forward(self, x, ref=None):
b,f,h,w = x.shape
if ref is None:
ref = torch.zeros((b,f,h*self.scale,w*self.scale), device=x.device, dtype=x.dtype)
x_grad = self.get_g_nopadding(x)
ref_grad = self.get_g_nopadding(ref)
x = self.model[0](x)
x_ref = self.ref_conv(ref)
x = self.join_conv(torch.cat([x, x_ref], dim=1))
x, block_list = self.model[1](x)
x_ori = x
x = checkpoint(self.bl1, x)
x_fea1 = x
x = checkpoint(self.bl2, x)
x_fea2 = x
x = checkpoint(self.bl3, x)
x_fea3 = x
x = checkpoint(self.bl4, x)
x_fea4 = x
x = checkpoint(self.bl5, x)
x = checkpoint(self.bl6, x_ori, x)
x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad)
x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1)
x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2)
x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3)
x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea)
x_out_branch = checkpoint(self.conv_w, x_branch)
########
x_branch_d = x_branch
x_f_cat = torch.cat([x_branch_d, x], dim=1)
x_f_cat = checkpoint(self.f_block, x_f_cat)
x_out = self.f_concat(x_f_cat)
x_out = checkpoint(self.f_HR_conv0, x_out)
x_out = checkpoint(self.f_HR_conv1, x_out)
#########
return x_out_branch, x_out, x_grad
class SPSRNetSimplified(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, upscale=4):
super(SPSRNetSimplified, self).__init__()
n_upscale = int(math.log(upscale, 2))
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.model_shortcut_blk = nn.Sequential(*[RRDB(nf, gc=32) for _ in range(nb)])
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
self.feature_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
# Grad branch
self.get_g_nopadding = ImageGradientNoPadding()
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
self.b_concat_decimate_1 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
self.b_proc_block_1 = RRDB(nf, gc=32)
self.b_concat_decimate_2 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
self.b_proc_block_2 = RRDB(nf, gc=32)
self.b_concat_decimate_3 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
self.b_proc_block_3 = RRDB(nf, gc=32)
self.b_concat_decimate_4 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False)
self.b_proc_block_4 = RRDB(nf, gc=32)
# Upsampling
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
b_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
grad_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
grad_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
self.branch_upsample = B.sequential(*b_upsampler, grad_hr_conv1, grad_hr_conv2)
# Conv used to output grad branch shortcut.
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
# Conjoin branch.
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
self._branch_pretrain_concat = ConvGnLelu(nf * 2, nf, kernel_size=1, norm=False, activation=False, bias=False)
self._branch_pretrain_block = RRDB(nf * 2, gc=32)
self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model_fea_conv(x)
x_ori = x
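        # Tap the trunk every five RRDB blocks; x_fea1..x_fea4 are consumed by the
        # gradient branch below.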
for i in range(5):
x = self.model_shortcut_blk[i](x)
x_fea1 = x
for i in range(5):
x = self.model_shortcut_blk[i + 5](x)
x_fea2 = x
for i in range(5):
x = self.model_shortcut_blk[i + 10](x)
x_fea3 = x
for i in range(5):
x = self.model_shortcut_blk[i + 15](x)
x_fea4 = x
x = self.model_shortcut_blk[20:](x)
x = self.feature_lr_conv(x)
# short cut
x = x_ori + x
x = self.model_upsampler(x)
x = self.feature_hr_conv1(x)
x = self.feature_hr_conv2(x)
x_b_fea = self.b_fea_conv(x_grad)
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_concat_decimate_1(x_cat_1)
x_cat_1 = self.b_proc_block_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_concat_decimate_2(x_cat_2)
x_cat_2 = self.b_proc_block_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_concat_decimate_3(x_cat_3)
x_cat_3 = self.b_proc_block_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_concat_decimate_4(x_cat_4)
x_cat_4 = self.b_proc_block_4(x_cat_4)
x_cat_4 = self.grad_lr_conv(x_cat_4)
# short cut
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.branch_upsample(x_cat_4)
x_out_branch = self.grad_branch_output_conv(x_branch)
########
x_branch_d = x_branch
x__branch_pretrain_cat = torch.cat([x_branch_d, x], dim=1)
x__branch_pretrain_cat = self._branch_pretrain_block(x__branch_pretrain_cat)
x_out = self._branch_pretrain_concat(x__branch_pretrain_cat)
x_out = self._branch_pretrain_HR_conv0(x_out)
x_out = self._branch_pretrain_HR_conv1(x_out)
#########
return x_out_branch, x_out, x_grad
# Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding.
class Spsr7(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10):
super(Spsr7, self).__init__()
n_upscale = int(math.log(upscale, 2))
# processing the input embedding
self.reference_embedding = ReferenceImageBranch(nf)
self.recurrent = recurrent
if recurrent:
self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False,
bias=True)
self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01)
# switch options
self.nf = nf
transformation_filters = nf
self.transformation_counts = xforms
multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters, embedding_channels=512, reductions=multiplexer_reductions)
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=3,
weight_init_factor=.1)
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
self.get_g_nopadding = ImageGradientNoPadding()
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False)
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True)
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
# Join branch (grad+fea)
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
self.attentions = None
self.init_temperature = init_temperature
self.final_temperature_step = 10000
self.lr = None
def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None):
        # The attention_maps debugger outputs the raw input x, so save a detached copy of it here.
self.lr = x.detach().cpu()
x_grad = self.get_g_nopadding(x)
ref_code = self.reference_embedding(ref, ref_center)
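        # Broadcast the reference code across space: reshape to (B, nf*8, 1, 1) and
        # tile it to (H/8, W/8) for consumption by the switch multiplexers.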
ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
x = self.model_fea_conv(x)
if self.recurrent:
rec = self.model_recurrent_conv(recurrent)
br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1))
x = x + br
x1 = x
x1, a1 = self.sw1(x1, identity=x, att_in=(x1, ref_embedding), do_checkpointing=True)
x2 = x1
x2, a2 = self.sw2(x2, identity=x1, att_in=(x2, ref_embedding), do_checkpointing=True)
x_grad = self.grad_conv(x_grad)
x_grad_identity = x_grad
x_grad, grad_fea_std = checkpoint(self.grad_ref_join, x_grad, x1)
x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity, att_in=(x_grad, ref_embedding), do_checkpointing=True)
x_grad = checkpoint(self.grad_lr_conv, x_grad)
x_grad = checkpoint(self.grad_lr_conv2, x_grad)
x_grad_out = checkpoint(self.upsample_grad, x_grad)
x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out)
x_out = x2
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, identity=x2, att_in=(x_out, ref_embedding), do_checkpointing=True)
x_out = checkpoint(self.final_lr_conv, x_out)
x_out = checkpoint(self.upsample, x_out)
x_out = checkpoint(self.final_hr_conv1, x_out)
x_out = checkpoint(self.final_hr_conv2, x_out)
self.attentions = [a1, a2, a3, a4]
self.grad_fea_std = grad_fea_std.detach().cpu()
self.fea_grad_std = fea_grad_std.detach().cpu()
return x_grad_out, x_out
    def set_temperature(self, temp):
        for sw in self.switches:
            sw.set_temperature(temp)
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
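            # Linearly anneal the switch temperature from init_temperature+1 down to 1
            # over the first final_temperature_step steps; clamped at 1 afterwards.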
temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp)
if step % 500 == 0:
output_path = os.path.join(experiments_path, "attention_maps")
prefix = "amap_%i_a%i_%%i.png"
                for i in range(len(self.attentions)):
                    save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts,
                                                prefix % (step, i), step, output_mag=False)
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))
def get_debug_values(self, step, net_name):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"switch_temperature": temp,
"grad_branch_feat_intg_std_dev": self.grad_fea_std,
"conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
class AttentionBlock(nn.Module):
def __init__(self, nf, num_transforms, multiplexer_reductions, init_temperature=10, has_ref=True):
super(AttentionBlock, self).__init__()
self.nf = nf
self.transformation_counts = num_transforms
multiplx_fn = functools.partial(QueryKeyMultiplexer, nf, embedding_channels=512, reductions=multiplexer_reductions)
transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5),
nf, kernel_size=3, depth=4,
weight_init_factor=.1)
if has_ref:
self.ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
else:
self.ref_join = None
self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
def forward(self, x, mplex_ref=None, ref=None):
if self.ref_join is not None:
branch, ref_std = self.ref_join(x, ref)
return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
else:
return self.switch(x, identity=x, att_in=(x, mplex_ref))
class SwitchedSpsr(nn.Module):
def __init__(self, in_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SwitchedSpsr, self).__init__()
n_upscale = int(math.log(upscale, 2))
# switch options
transformation_filters = nf
switch_filters = nf
switch_reductions = 3
switch_processing_layers = 2
self.transformation_counts = xforms
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
switch_processing_layers, self.transformation_counts, use_exp2=True)
pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=3,
weight_init_factor=.1)
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
# Grad branch
self.get_g_nopadding = ImageGradientNoPadding()
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
# Upsampling
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
# Conv used to output grad branch shortcut.
self.grad_branch_output_conv = ConvGnLelu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
# Conjoin branch.
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=4,
weight_init_factor=.1)
pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)])
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)])
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=True)
self.final_hr_conv2 = ConvGnLelu(nf, 3, kernel_size=3, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
self.attentions = None
self.init_temperature = init_temperature
self.final_temperature_step = 10000
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model_fea_conv(x)
x1, a1 = self.sw1(x, do_checkpointing=True)
x2, a2 = self.sw2(x1, do_checkpointing=True)
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_hr_conv2(x_fea)
x_b_fea = self.b_fea_conv(x_grad)
x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True, do_checkpointing=True)
x_grad = checkpoint(self.grad_lr_conv, x_grad)
x_grad = checkpoint(self.grad_hr_conv, x_grad)
x_out_branch = checkpoint(self.upsample_grad, x_grad)
x_out_branch = self.grad_branch_output_conv(x_out_branch)
x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True)
x_out = checkpoint(self.final_lr_conv, x__branch_pretrain_cat)
x_out = checkpoint(self.upsample, x_out)
x_out = checkpoint(self.final_hr_conv1, x_out)
x_out = self.final_hr_conv2(x_out)
self.attentions = [a1, a2, a3, a4]
return x_out_branch, x_out, x_grad
    def set_temperature(self, temp):
        for sw in self.switches:
            sw.set_temperature(temp)
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp)
if step % 200 == 0:
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
prefix = "attention_map_%i_%%i.png" % (step,)
                for i in range(len(self.attentions)):
                    save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step)
def get_debug_values(self, step, net):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"switch_temperature": temp}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val

View File

@@ -1,163 +0,0 @@
from collections import OrderedDict
import torch
import torch.nn as nn
####################
# Basic blocks
####################
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
# helper selecting activation
# neg_slope: for leakyrelu and init of prelu
# n_prelu: for p_relu num_parameters
act_type = act_type.lower()
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type == 'leakyrelu':
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
else:
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
return layer
def norm(norm_type, nc):
# helper selecting normalization layer
norm_type = norm_type.lower()
if norm_type == 'batch':
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
elif norm_type == 'group':
layer = nn.GroupNorm(8, nc)
else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
return layer
def pad(pad_type, padding):
# helper selecting padding layer
# if padding is 'zero', do by conv layers
pad_type = pad_type.lower()
if padding == 0:
return None
if pad_type == 'reflect':
layer = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
layer = nn.ReplicationPad2d(padding)
else:
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
return layer
def get_valid_padding(kernel_size, dilation):
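    # 'Same' padding for a stride-1 conv: e.g. kernel_size=3, dilation=2 ->
    # effective kernel 5 -> padding 2.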
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
class ConcatBlock(nn.Module):
# Concat the output of a submodule to its input
def __init__(self, submodule):
super(ConcatBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = torch.cat((x, self.sub(x)), dim=1)
return output
def __repr__(self):
tmpstr = 'Identity .. \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlock(nn.Module):
    # Returns its input together with the wrapped submodule (rather than summing
    # their outputs); callers such as SPSRNet apply the submodule piecewise so
    # each chunk can be gradient-checkpointed.
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule
    def forward(self, x):
        return x, self.sub
def __repr__(self):
tmpstr = 'Identity + \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
def sequential(*args):
# Flatten Sequential. It unwraps nn.Sequential.
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('sequential does not support OrderedDict input.')
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
'''
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
'''
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
dilation=dilation, bias=bias, groups=groups)
a = act(act_type) if act_type else None
if 'CNA' in mode:
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
elif mode == 'NAC':
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
# Important!
# input----ReLU(inplace)----Conv--+----output
# |________________________|
# inplace ReLU will modify the input, therefore wrong output
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
####################
# Upsampler
####################
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu'):
'''
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
'''
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
pad_type=pad_type, norm_type=None, act_type=None)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
pad_type=pad_type, norm_type=norm_type, act_type=act_type)
return sequential(upsample, conv)


@@ -1,55 +0,0 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
class MSRResNet(nn.Module):
''' modified SRResNet'''
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
super(MSRResNet, self).__init__()
self.upscale = upscale
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
self.recon_trunk = arch_util.make_layer(basic_block, nb)
# upsampling
if self.upscale == 2:
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
elif self.upscale == 3:
self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(3)
elif self.upscale == 4:
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# initialization
arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
0.1)
if self.upscale == 4:
arch_util.initialize_weights(self.upconv2, 0.1)
def forward(self, x):
fea = self.lrelu(self.conv_first(x))
out = self.recon_trunk(fea)
if self.upscale == 4:
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
elif self.upscale == 3 or self.upscale == 2:
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.conv_last(self.lrelu(self.HRconv(out)))
base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
out += base
return out


@@ -1,139 +0,0 @@
import functools
import torch
from torch.nn import init
import models.archs.biggan.biggan_layers as layers
import torch.nn as nn
# Discriminator architecture dict, following the same paradigm as BigGAN's generator arch (not defined in this file)
def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'):
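    # Each entry lists per-block channel widths, downsampling flags, and block input
    # resolutions, plus a map resolution -> bool marking where self-attention is
    # inserted (parsed from the underscore-separated `attention` string, e.g. '64' or '32_64').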
arch = {}
arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
'downsample' : [True] * 6 + [False],
'resolution' : [128, 64, 32, 16, 8, 4, 4 ],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,8)}}
arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]],
'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
'downsample' : [True] * 5 + [False],
'resolution' : [64, 32, 16, 8, 4, 4],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,8)}}
arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]],
'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
'downsample' : [True] * 4 + [False],
'resolution' : [32, 16, 8, 4, 4],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,7)}}
arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]],
'out_channels' : [item * ch for item in [4, 4, 4, 4]],
'downsample' : [True, True, False, False],
'resolution' : [16, 16, 16, 16],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,6)}}
return arch
class BigGanDiscriminator(nn.Module):
def __init__(self, D_ch=64, D_wide=True, resolution=128,
D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
SN_eps=1e-12, output_dim=1, D_fp16=False,
D_init='ortho', skip_init=False, D_param='SN'):
super(BigGanDiscriminator, self).__init__()
# Width multiplier
self.ch = D_ch
# Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
self.D_wide = D_wide
# Resolution
self.resolution = resolution
# Kernel size
self.kernel_size = D_kernel_size
# Attention?
self.attention = D_attn
# Activation
self.activation = D_activation
# Initialization style
self.init = D_init
# Parameterization style
self.D_param = D_param
# Epsilon for Spectral Norm?
self.SN_eps = SN_eps
# Fp16?
self.fp16 = D_fp16
# Architecture
self.arch = D_arch(self.ch, self.attention)[resolution]
# Which convs, batchnorms, and linear layers to use
# No option to turn off SN in D right now
if self.D_param == 'SN':
self.which_conv = functools.partial(layers.SNConv2d,
kernel_size=3, padding=1,
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
eps=self.SN_eps)
self.which_linear = functools.partial(layers.SNLinear,
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
eps=self.SN_eps)
self.which_embedding = functools.partial(layers.SNEmbedding,
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
eps=self.SN_eps)
# Prepare model
# self.blocks is a doubly-nested list of modules, the outer loop intended
# to be over blocks at a given resolution (resblocks and/or self-attention)
self.blocks = []
for index in range(len(self.arch['out_channels'])):
self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
out_channels=self.arch['out_channels'][index],
which_conv=self.which_conv,
wide=self.D_wide,
activation=self.activation,
preactivation=(index > 0),
downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
# If attention on this block, attach it to the end
if self.arch['attention'][self.arch['resolution'][index]]:
print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
self.which_conv)]
# Turn self.blocks into a ModuleList so that it's all properly registered.
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
# Linear output layer. The output dimension is typically 1, but may be
# larger if we're e.g. turning this into a VAE with an inference output
self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
# Initialize weights
if not skip_init:
self.init_weights()
# Initialize
def init_weights(self):
self.param_count = 0
for module in self.modules():
if (isinstance(module, nn.Conv2d)
or isinstance(module, nn.Linear)
or isinstance(module, nn.Embedding)):
if self.init == 'ortho':
init.orthogonal_(module.weight)
elif self.init == 'N02':
init.normal_(module.weight, 0, 0.02)
elif self.init in ['glorot', 'xavier']:
init.xavier_uniform_(module.weight)
else:
print('Init style not recognized...')
self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print("Param count for D's initialized parameters: %d" % self.param_count)
def forward(self, x, y=None):
# Stick x into h for cleaner for loops without flow control
h = x
# Loop over blocks
for index, blocklist in enumerate(self.blocks):
for block in blocklist:
h = block(h)
# Apply global sum pooling as in SN-GAN
h = torch.sum(self.activation(h), [2, 3])
# Get initial class-unconditional output
out = self.linear(h)
return out


@@ -1,457 +0,0 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
# Projection of x onto y
def proj(x, y):
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
for y in ys:
x = x - proj(x, y)
return x
# Apply num_itrs steps of the power method to estimate top N singular values.
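# Each iteration updates v proportional to u@W and u proportional to v@W.t(),
# Gram-Schmidt-orthogonalized against the vectors already found; svs[0] then
# estimates the largest singular value, which the SN layers below divide out of the weight.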
def power_iteration(W, u_, update=True, eps=1e-12):
# Lists holding singular vectors and values
us, vs, svs = [], [], []
for i, u in enumerate(u_):
# Run one step of the power iteration
with torch.no_grad():
v = torch.matmul(u, W)
# Run Gram-Schmidt to subtract components of all other singular vectors
v = F.normalize(gram_schmidt(v, vs), eps=eps)
# Add to the list
vs += [v]
# Update the other singular vector
u = torch.matmul(v, W.t())
# Run Gram-Schmidt to subtract components of all other singular vectors
u = F.normalize(gram_schmidt(u, us), eps=eps)
# Add to the list
us += [u]
if update:
u_[i][:] = u
# Compute this singular value and add it to the list
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
# svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
return svs, us, vs
# Convenience passthrough function
class identity(nn.Module):
def forward(self, input):
return input
# Spectral normalization base class
class SN(object):
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
# Number of power iterations per step
self.num_itrs = num_itrs
# Number of singular values
self.num_svs = num_svs
# Transposed?
self.transpose = transpose
# Epsilon value for avoiding divide-by-0
self.eps = eps
# Register a singular vector for each sv
for i in range(self.num_svs):
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
self.register_buffer('sv%d' % i, torch.ones(1))
# Singular vectors (u side)
@property
def u(self):
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
# Singular values;
# note that these buffers are just for logging and are not used in training.
@property
def sv(self):
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
# Compute the spectrally-normalized weight
def W_(self):
W_mat = self.weight.view(self.weight.size(0), -1)
if self.transpose:
W_mat = W_mat.t()
# Apply num_itrs power iterations
for _ in range(self.num_itrs):
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
# Update the svs
if self.training:
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
for i, sv in enumerate(svs):
self.sv[i][:] = sv
return self.weight / svs[0]
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
def forward(self, x):
return F.conv2d(x, self.W_(), self.bias, self.stride,
self.padding, self.dilation, self.groups)
# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
def __init__(self, in_features, out_features, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Linear.__init__(self, in_features, out_features, bias)
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
def forward(self, x):
return F.linear(x, self.W_(), self.bias)
# Embedding layer with spectral norm
# We use num_embeddings as the spectral-norm dimension instead of embedding_dim
# for convenience's sake
class SNEmbedding(nn.Embedding, SN):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=None,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
max_norm, norm_type, scale_grad_by_freq,
sparse, _weight)
SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
def forward(self, x):
return F.embedding(x, self.W_())
# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
def __init__(self, ch, which_conv=SNConv2d, name='attention'):
super(Attention, self).__init__()
# Channel multiplier
self.ch = ch
self.which_conv = which_conv
self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
# Learnable gain parameter
self.gamma = P(torch.tensor(0.), requires_grad=True)
def forward(self, x, y=None):
# Apply convs
theta = self.theta(x)
phi = F.max_pool2d(self.phi(x), [2, 2])
g = F.max_pool2d(self.g(x), [2, 2])
# Perform reshapes
theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
# Matmul and softmax to get attention maps
beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
# Attention map times g path
o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
return self.gamma * o + x
# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
# Apply scale and shift--if gain and bias are provided, fuse them here
# Prepare scale
scale = torch.rsqrt(var + eps)
# If a gain is provided, use it
if gain is not None:
scale = scale * gain
# Prepare shift
shift = mean * scale
# If bias is provided, use it
if bias is not None:
shift = shift - bias
return x * scale - shift
# return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
# Cast x to float32 if necessary
float_x = x.float()
# Calculate expected value of x (m) and expected value of x**2 (m2)
# Mean of x
m = torch.mean(float_x, [0, 2, 3], keepdim=True)
# Mean of x squared
m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
# Calculate variance as mean of squared minus mean squared.
var = (m2 - m ** 2)
# Cast back to float 16 if necessary
var = var.type(x.type())
m = m.type(x.type())
# Return mean and variance for updating stored mean/var if requested
if return_mean_var:
return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
else:
return fused_bn(x, m, var, gain, bias, eps)
# My batchnorm, supports standing stats
class myBN(nn.Module):
def __init__(self, num_channels, eps=1e-5, momentum=0.1):
super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
# Register buffers
self.register_buffer('stored_mean', torch.zeros(num_channels))
self.register_buffer('stored_var', torch.ones(num_channels))
self.register_buffer('accumulation_counter', torch.zeros(1))
# Accumulate running means and vars
self.accumulate_standing = False
# reset standing stats
def reset_stats(self):
self.stored_mean[:] = 0
self.stored_var[:] = 0
self.accumulation_counter[:] = 0
def forward(self, x, gain, bias):
if self.training:
out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
# If accumulating standing stats, increment them
if self.accumulate_standing:
self.stored_mean[:] = self.stored_mean + mean.data
self.stored_var[:] = self.stored_var + var.data
self.accumulation_counter += 1.0
# If not accumulating standing stats, take running averages
else:
self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
return out
# If not in training mode, use the stored statistics
else:
mean = self.stored_mean.view(1, -1, 1, 1)
var = self.stored_var.view(1, -1, 1, 1)
# If using standing stats, divide them by the accumulation counter
if self.accumulate_standing:
mean = mean / self.accumulation_counter
var = var / self.accumulation_counter
return fused_bn(x, mean, var, gain, bias, self.eps)
# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
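    # norm_style examples: 'ch_16' -> groups of ~16 channels each,
    # 'grp_8' -> 8 groups, anything else -> 16 groups.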
# If number of channels specified in norm_style:
if 'ch' in norm_style:
ch = int(norm_style.split('_')[-1])
groups = max(int(x.shape[1]) // ch, 1)
# If number of groups specified in norm style
elif 'grp' in norm_style:
groups = int(norm_style.split('_')[-1])
# If neither, default to groups = 16
else:
groups = 16
return F.group_norm(x, groups)
# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
cross_replica=False, mybn=False, norm_style='bn', ):
super(ccbn, self).__init__()
self.output_size, self.input_size = output_size, input_size
# Prepare gain and bias layers
self.gain = which_linear(input_size, output_size)
self.bias = which_linear(input_size, output_size)
# epsilon to avoid dividing by 0
self.eps = eps
# Momentum
self.momentum = momentum
# Use cross-replica batchnorm?
self.cross_replica = cross_replica
# Use my batchnorm?
self.mybn = mybn
# Norm style?
self.norm_style = norm_style
if self.cross_replica or self.mybn:
self.bn = myBN(output_size, self.eps, self.momentum)
elif self.norm_style in ['bn', 'in']:
self.register_buffer('stored_mean', torch.zeros(output_size))
self.register_buffer('stored_var', torch.ones(output_size))
def forward(self, x, y):
# Calculate class-conditional gains and biases
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
bias = self.bias(y).view(y.size(0), -1, 1, 1)
# If using my batchnorm
if self.mybn or self.cross_replica:
return self.bn(x, gain=gain, bias=bias)
# else:
else:
if self.norm_style == 'bn':
out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
self.training, 0.1, self.eps)
elif self.norm_style == 'in':
out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
self.training, 0.1, self.eps)
elif self.norm_style == 'gn':
                out = groupnorm(x, self.norm_style)
elif self.norm_style == 'nonorm':
out = x
return out * gain + bias
def extra_repr(self):
s = 'out: {output_size}, in: {input_size},'
s += ' cross_replica={cross_replica}'
return s.format(**self.__dict__)
# Normal, non-class-conditional BN
class bn(nn.Module):
def __init__(self, output_size, eps=1e-5, momentum=0.1,
cross_replica=False, mybn=False):
super(bn, self).__init__()
self.output_size = output_size
# Prepare gain and bias layers
self.gain = P(torch.ones(output_size), requires_grad=True)
self.bias = P(torch.zeros(output_size), requires_grad=True)
# epsilon to avoid dividing by 0
self.eps = eps
# Momentum
self.momentum = momentum
# Use cross-replica batchnorm?
self.cross_replica = cross_replica
# Use my batchnorm?
self.mybn = mybn
if self.cross_replica or mybn:
self.bn = myBN(output_size, self.eps, self.momentum)
# Register buffers if neither of the above
else:
self.register_buffer('stored_mean', torch.zeros(output_size))
self.register_buffer('stored_var', torch.ones(output_size))
def forward(self, x, y=None):
if self.cross_replica or self.mybn:
gain = self.gain.view(1, -1, 1, 1)
bias = self.bias.view(1, -1, 1, 1)
return self.bn(x, gain=gain, bias=bias)
else:
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
self.bias, self.training, self.momentum, self.eps)
# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
def __init__(self, in_channels, out_channels,
which_conv=nn.Conv2d, which_bn=bn, activation=None,
upsample=None):
super(GBlock, self).__init__()
self.in_channels, self.out_channels = in_channels, out_channels
self.which_conv, self.which_bn = which_conv, which_bn
self.activation = activation
self.upsample = upsample
# Conv layers
self.conv1 = self.which_conv(self.in_channels, self.out_channels)
self.conv2 = self.which_conv(self.out_channels, self.out_channels)
self.learnable_sc = in_channels != out_channels or upsample
if self.learnable_sc:
self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
# Batchnorm layers
self.bn1 = self.which_bn(in_channels)
self.bn2 = self.which_bn(out_channels)
# upsample layers
self.upsample = upsample
def forward(self, x, y):
h = self.activation(self.bn1(x, y))
if self.upsample:
h = self.upsample(h)
x = self.upsample(x)
h = self.conv1(h)
h = self.activation(self.bn2(h, y))
h = self.conv2(h)
if self.learnable_sc:
x = self.conv_sc(x)
return h + x
# Residual block for the discriminator
class DBlock(nn.Module):
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
preactivation=False, activation=None, downsample=None, ):
super(DBlock, self).__init__()
self.in_channels, self.out_channels = in_channels, out_channels
# If using wide D (as in SA-GAN and BigGAN), change the channel pattern
self.hidden_channels = self.out_channels if wide else self.in_channels
self.which_conv = which_conv
self.preactivation = preactivation
self.activation = activation
self.downsample = downsample
# Conv layers
self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
self.learnable_sc = True if (in_channels != out_channels) or downsample else False
if self.learnable_sc:
self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
def shortcut(self, x):
if self.preactivation:
if self.learnable_sc:
x = self.conv_sc(x)
if self.downsample:
x = self.downsample(x)
else:
if self.downsample:
x = self.downsample(x)
if self.learnable_sc:
x = self.conv_sc(x)
return x
def forward(self, x):
if self.preactivation:
# h = self.activation(x) # NOT TODAY SATAN
# Andy's note: This line *must* be an out-of-place ReLU or it
# will negatively affect the shortcut connection.
h = F.relu(x)
else:
h = x
h = self.conv1(h)
h = self.conv2(self.activation(h))
if self.downsample:
h = self.downsample(h)
return h + self.shortcut(x)


@@ -1,14 +1,11 @@
import copy
import random
from functools import wraps
from time import time
import torch
import torch.nn.functional as F
from torch import nn
from data.byol_attachment import reconstructed_shared_regions
-from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \
+from models.archs.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
update_moving_average
from utils.util import checkpoint

@@ -0,0 +1 @@
+Subproject commit db2b7899ea8506e90418dbd389300c49bdbb55c3


@@ -1,47 +0,0 @@
import torch
from torch import nn
from lambda_networks import LambdaLayer
from torch.nn import GroupNorm
from models.archs.RRDBNet_arch import ResidualDenseBlock
from models.archs.arch_util import ConvGnLelu
class LambdaRRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
mid_channels (int): Channel number of intermediate features.
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
super(LambdaRRDB, self).__init__()
if reduce_to is None:
reduce_to = mid_channels
self.lam1 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
self.gn1 = GroupNorm(num_groups=8, num_channels=mid_channels)
self.lam2 = LambdaLayer(dim=mid_channels, dim_out=mid_channels, r=23, dim_k=16, heads=4, dim_u=4)
self.gn2 = GroupNorm(num_groups=8, num_channels=mid_channels)
self.lam3 = LambdaLayer(dim=mid_channels, dim_out=reduce_to, r=23, dim_k=16, heads=4, dim_u=4)
        self.gn3 = GroupNorm(num_groups=8, num_channels=reduce_to)  # matches lam3's output width
self.conv = ConvGnLelu(reduce_to, reduce_to, kernel_size=1, bias=True, norm=False, activation=False, weight_init_factor=.1)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.lam1(x)
out = self.gn1(out)
out = self.lam2(out)
        out = self.gn2(out)
out = self.lam3(out)
out = self.gn3(out)
return self.conv(out) * .2 + x


@@ -1,206 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
from models.archs.RRDBNet_arch import RRDB
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu, PixelUnshuffle
from utils.util import checkpoint, sequential_checkpoint
class MultiLevelRRDB(nn.Module):
def __init__(self, nf, gc, levels):
super().__init__()
self.levels = levels
self.level_rrdbs = nn.ModuleList([RRDB(nf, growth_channels=gc) for i in range(levels)])
    # Trunks should be fed in HR->LR order
def forward(self, trunk):
for i in reversed(range(self.levels)):
lvl_scale = (2**i)
lvl_res = self.level_rrdbs[i](F.interpolate(trunk, scale_factor=1/lvl_scale, mode="area"), return_residual=True)
trunk = trunk + F.interpolate(lvl_res, scale_factor=lvl_scale, mode="nearest")
return trunk
class MultiResRRDBNet(nn.Module):
def __init__(self,
in_channels,
out_channels,
mid_channels=64,
l1_blocks=3,
l2_blocks=4,
l3_blocks=6,
growth_channels=32,
scale=4,
):
super().__init__()
self.scale = scale
self.in_channels = in_channels
self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=1, padding=3)
self.l3_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 3) for _ in range(l1_blocks)])
self.l2_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)])
self.l1_blocks = nn.ModuleList([MultiLevelRRDB(mid_channels, growth_channels, 1) for _ in range(l3_blocks)])
self.block_levels = [self.l3_blocks, self.l2_blocks, self.l1_blocks]
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for m in [
            self.conv_first, self.conv_body, self.conv_up1,
self.conv_up2, self.conv_hr, self.conv_last
]:
if m is not None:
default_init_weights(m, 0.1)
def forward(self, x):
trunk = self.conv_first(x)
for block_set in self.block_levels:
for block in block_set:
trunk = checkpoint(block, trunk)
body_feat = self.conv_body(trunk)
feat = trunk + body_feat
# upsample
out = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
if self.scale == 4:
out = self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
else:
out = self.lrelu(self.conv_up2(out))
out = self.conv_last(self.lrelu(self.conv_hr(out)))
return out
def visual_dbg(self, step, path):
pass
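
A hedged construction sketch (the argument values are illustrative, and this assumes the repo's RRDB supports the return_residual flag used above and that utils.util.checkpoint reduces to a plain call at eval time):

import torch

net = MultiResRRDBNet(in_channels=3, out_channels=3, mid_channels=64,
                      l1_blocks=3, l2_blocks=4, l3_blocks=6,
                      growth_channels=32, scale=4)
out = net(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 128, 128]) for scale=4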
class SteppedResRRDBNet(nn.Module):
def __init__(self,
in_channels,
out_channels,
mid_channels=64,
l1_blocks=3,
l2_blocks=3,
growth_channels=32,
scale=4,
):
super().__init__()
self.scale = scale
self.in_channels = in_channels
self.conv_first = nn.Conv2d(in_channels, mid_channels, 7, stride=2, padding=3)
self.conv_second = nn.Conv2d(mid_channels, mid_channels*2, 3, stride=2, padding=1)
self.l1_blocks = nn.Sequential(*[RRDB(mid_channels*2, growth_channels*2) for _ in range(l1_blocks)])
self.l1_upsample_conv = nn.Conv2d(mid_channels*2, mid_channels, 3, stride=1, padding=1)
self.l2_blocks = nn.Sequential(*[RRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)])
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for m in [
self.conv_first, self.conv_second, self.l1_upsample_conv, self.conv_body, self.conv_up1,
self.conv_up2, self.conv_hr, self.conv_last
]:
if m is not None:
default_init_weights(m, 0.1)
def forward(self, x):
trunk = self.conv_first(x)
trunk = self.conv_second(trunk)
        trunk = sequential_checkpoint(self.l1_blocks, len(self.l1_blocks), trunk)
trunk = F.interpolate(trunk, scale_factor=2, mode="nearest")
trunk = self.l1_upsample_conv(trunk)
trunk = sequential_checkpoint(self.l2_blocks, len(self.l2_blocks), trunk)
body_feat = self.conv_body(trunk)
feat = trunk + body_feat
# upsample
out = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
if self.scale == 4:
out = self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
else:
out = self.lrelu(self.conv_up2(out))
out = self.conv_last(self.lrelu(self.conv_hr(out)))
return out
def visual_dbg(self, step, path):
pass
class PixelShufflingSteppedResRRDBNet(nn.Module):
def __init__(self,
in_channels,
out_channels,
mid_channels=64,
l1_blocks=3,
l2_blocks=3,
growth_channels=32,
scale=2,
):
super().__init__()
self.scale = scale * 2 # This RRDB operates at half-scale resolution.
self.in_channels = in_channels
self.pix_unshuffle = PixelUnshuffle(4)
self.conv_first = nn.Conv2d(4*4*in_channels, mid_channels*2, 3, stride=1, padding=1)
self.l1_blocks = nn.Sequential(*[RRDB(mid_channels*2, growth_channels*2) for _ in range(l1_blocks)])
self.l1_upsample_conv = nn.Conv2d(mid_channels*2, mid_channels, 3, stride=1, padding=1)
self.l2_blocks = nn.Sequential(*[RRDB(mid_channels, growth_channels, 2) for _ in range(l2_blocks)])
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for m in [
self.conv_first, self.l1_upsample_conv, self.conv_body, self.conv_up1,
self.conv_up2, self.conv_hr, self.conv_last
]:
if m is not None:
default_init_weights(m, 0.1)
def forward(self, x):
trunk = self.conv_first(self.pix_unshuffle(x))
trunk = sequential_checkpoint(self.l1_blocks, len(self.l1_blocks), trunk)
trunk = F.interpolate(trunk, scale_factor=2, mode="nearest")
trunk = self.l1_upsample_conv(trunk)
trunk = sequential_checkpoint(self.l2_blocks, len(self.l2_blocks), trunk)
body_feat = self.conv_body(trunk)
feat = trunk + body_feat
# upsample
out = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
if self.scale == 4:
out = self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
else:
out = self.lrelu(self.conv_up2(out))
out = self.conv_last(self.lrelu(self.conv_hr(out)))
return out
def visual_dbg(self, step, path):
pass
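
Tracing shapes through the pixel-shuffled variant for a (1, 3, 64, 64) input with the defaults above; the 4x PixelUnshuffle means the heavy trunk runs at quarter resolution:

# x:             (1, 3, 64, 64)
# pix_unshuffle: (1, 48, 16, 16)    4*4*in_channels
# conv_first:    (1, 128, 16, 16)   mid_channels*2
# l1_blocks:     (1, 128, 16, 16)
# interpolate:   (1, 128, 32, 32)
# l1_upsample:   (1, 64, 32, 32)
# l2_blocks:     (1, 64, 32, 32)
# conv_up1 (x2): (1, 64, 64, 64)
# conv_up2 (x2): (1, 64, 128, 128)  taken because self.scale == 4 when scale=2
# conv_last:     (1, 3, 128, 128)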

View File

@ -1,98 +0,0 @@
import torch
from torch import nn
from models.archs.arch_util import ConvGnLelu, ExpansionBlock
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from utils.util import checkpoint
import torch.nn.functional as F
class Pyramid(nn.Module):
def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
norm=True, return_outlevels=False):
super(Pyramid, self).__init__()
levels = []
current_filters = nf
self.return_outlevels = return_outlevels
for d in range(depth):
level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
current_filters = int(current_filters*scale_per_level)
for pc in range(processing_convs_per_layer):
level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
levels.append(nn.Sequential(*level))
self.downsamples = nn.ModuleList(levels)
if processing_at_point > 0:
point_processor = []
for p in range(processing_at_point):
point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
self.point_processor = nn.Sequential(*point_processor)
else:
self.point_processor = None
levels = []
for d in range(depth):
level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
current_filters = int(current_filters / scale_per_level)
for pc in range(processing_convs_per_layer):
level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
levels.append(nn.ModuleList(level))
self.upsamples = nn.ModuleList(levels)
def forward(self, x):
passthroughs = []
fea = x
for lvl in self.downsamples:
passthroughs.append(fea)
fea = lvl(fea)
out_levels = []
        if self.point_processor is not None:
            fea = self.point_processor(fea)
for i, lvl in enumerate(self.upsamples):
out_levels.append(fea)
for j, sublvl in enumerate(lvl):
if j == 0:
fea = sublvl(fea, passthroughs[-1-i])
else:
fea = sublvl(fea)
out_levels.append(fea)
if self.return_outlevels:
return tuple(out_levels)
else:
return fea
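
A small standalone sketch of the pyramid (assuming ExpansionBlock accepts the two-argument call used in forward); with return_outlevels=True it yields features from coarsest to finest:

import torch

pyr = Pyramid(nf=32, depth=3, processing_convs_per_layer=1, processing_at_point=1,
              return_outlevels=True)
for f in pyr(torch.randn(1, 32, 64, 64)):
    print(f.shape)  # (1, 256, 8, 8), (1, 128, 16, 16), (1, 64, 32, 32), (1, 32, 64, 64)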
class BasicResamplingFlowNet(nn.Module):
def create_termini(self, filters):
return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
nn.Tanh())
def __init__(self, nf, resample_scale=1):
super(BasicResamplingFlowNet, self).__init__()
self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
self.create_termini(nf*1.5**2),
self.create_termini(nf*1.5)])
self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
nn.Tanh())
self.scale = resample_scale
self.resampler = Resample2d()
def forward(self, left, right):
fea = self.initial_conv(torch.cat([left, right], dim=1))
levels = checkpoint(self.pyramid, fea)
flos = []
compares = []
for i, level in enumerate(levels):
if i == 3:
flow = checkpoint(self.terminus, level) * self.scale
else:
flow = self.termini[i](level) * self.scale
img_scale = 1/2**(3-i)
flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
return flos, compares, flos_structural_var
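
Resample2d comes from the flownet2 CUDA extension. Where that is unavailable, a rough pure-PyTorch stand-in for backward-warping an image by a pixel-offset flow can be built on F.grid_sample (a sketch, not the extension's exact semantics):

import torch
import torch.nn.functional as F

def warp(img, flow):
    # img: (n, c, h, w); flow: (n, 2, h, w) in pixel offsets (x, y)
    n, _, h, w = img.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    base = torch.stack((xs, ys)).float().to(img.device)   # (2, h, w) absolute coords
    coords = base.unsqueeze(0) + flow
    gx = 2.0 * coords[:, 0] / (w - 1) - 1.0               # normalize to [-1, 1]
    gy = 2.0 * coords[:, 1] / (h - 1) - 1.0
    return F.grid_sample(img, torch.stack((gx, gy), dim=-1), align_corners=True)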

View File

@ -1,80 +0,0 @@
import torch
import torch.nn.functional as F
from math import exp
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def _ssim(img1, img2, window, window_size, channel, size_average=True, raw=False):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
elif raw:
return ssim_map
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, raw=False):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.raw = raw
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average, self.raw)
def ssim(img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
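
Usage is straightforward; identical images score 1.0 and unrelated noise scores well below it:

import torch

img = torch.rand(1, 3, 64, 64)
print(ssim(img, img).item())                   # ~1.0
print(ssim(img, torch.rand_like(img)).item())  # well below 1.0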

View File

@ -1,221 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.util import checkpoint
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)
class MeanShift(nn.Conv2d):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.weight.data.div_(std.view(3, 1, 1, 1))
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
        # Setting self.requires_grad is a no-op on a Module; freeze the parameters instead.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, bias=False,
bn=True, act=nn.ReLU(True)):
m = [nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), stride=stride, bias=bias)
]
if bn: m.append(nn.BatchNorm2d(out_channels))
if act is not None: m.append(act)
super(BasicBlock, self).__init__(*m)
class ResBlock(nn.Module):
def __init__(
self, conv, n_feat, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(n_feat))
if i == 0: m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x
return res
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feat, 4 * n_feat, 3, bias))
m.append(nn.PixelShuffle(2))
if bn: m.append(nn.BatchNorm2d(n_feat))
if act: m.append(act())
elif scale == 3:
m.append(conv(n_feat, 9 * n_feat, 3, bias))
m.append(nn.PixelShuffle(3))
if bn: m.append(nn.BatchNorm2d(n_feat))
if act: m.append(act())
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
def make_model(args, parent=False):
return RCAN(args)
## Channel Attention (CA) Layer
class CALayer(nn.Module):
def __init__(self, channel, reduction=16):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
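
CALayer is a standard squeeze-and-excitation gate: global average pool to (n, c, 1, 1), a bottleneck through c // reduction channels, then a sigmoid rescaling of the input. A quick shape check:

import torch

ca = CALayer(channel=64, reduction=16)
x = torch.randn(2, 64, 24, 24)
print(ca(x).shape)  # torch.Size([2, 64, 24, 24]); channels are rescaled, not mixed spatially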
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
def __init__(
self, conv, n_feat, kernel_size, reduction,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(RCAB, self).__init__()
modules_body = []
for i in range(2):
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: modules_body.append(nn.BatchNorm2d(n_feat))
if i == 0: modules_body.append(act)
modules_body.append(CALayer(n_feat, reduction))
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x)
# res = self.body(x).mul(self.res_scale)
res += x
return res
## Residual Group (RG)
class ResidualGroup(nn.Module):
def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
super(ResidualGroup, self).__init__()
modules_body = [
RCAB(
conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
for _ in range(n_resblocks)]
modules_body.append(conv(n_feat, n_feat, kernel_size))
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res += x
return res
## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
def __init__(self, args, conv=default_conv):
super(RCAN, self).__init__()
n_resgroups = args.n_resgroups
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
reduction = args.reduction
scale = args.scale
act = nn.ReLU(True)
# RGB mean for DIV2K
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
modules_body = [
ResidualGroup(
conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
for _ in range(n_resgroups)]
modules_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
modules_tail = [
Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=False):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') >= 0:
                        print('Replacing the pre-trained upsampler with a new one...')
else:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
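
RCAN takes an args namespace rather than keyword arguments; networks.py built one by munchifying the opt dict. A minimal equivalent with illustrative values:

from munch import munchify

args = munchify({'n_resgroups': 10, 'n_resblocks': 20, 'n_feats': 64,
                 'reduction': 16, 'scale': 4, 'res_scale': 1,
                 'rgb_range': 255, 'n_colors': 3})
net = RCAN(args)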

View File

@ -2,8 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from models.steps.injectors import Injector
from models.injectors import Injector
from utils.util import checkpoint

View File

@ -6,9 +6,8 @@ import torchvision
from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map
from models.steps.injectors import Injector
from models.steps.losses import extract_params_from_state
from models.steps.tecogan_losses import extract_inputs_index
from models.injectors import Injector
from models.losses import extract_params_from_state
import os.path as osp

View File

@ -1,8 +1,8 @@
import torch
from torch.cuda.amp import autocast
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.flownet2.utils.flow_utils import flow2img
from models.steps.injectors import Injector
from models.archs.flownet2.networks import Resample2d
from models.archs.flownet2 import flow2img
from models.injectors import Injector
def create_stereoscopic_injector(opt, env):

View File

@ -1,15 +1,15 @@
from torch.cuda.amp import autocast
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector
from models.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.archs.flownet2.networks import Resample2d
from models.injectors import Injector
import torch
import torch.nn.functional as F
import os
import os.path as osp
import torchvision
import torch.distributed as dist
def create_teco_loss(opt, env):
type = opt['type']

View File

@ -3,22 +3,20 @@ import random
import torch.nn
from torch.cuda.amp import autocast
from models.archs.SPSR_arch import ImageGradientNoPadding
from models.archs.pytorch_ssim import SSIM
from utils.weight_scheduler import get_scheduler_for_opt
from models.steps.losses import extract_params_from_state
from models.losses import extract_params_from_state
# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
type = opt_inject['type']
if 'teco_' in type:
from models.steps.tecogan_losses import create_teco_injector
from models.custom_training_components import create_teco_injector
return create_teco_injector(opt_inject, env)
elif 'progressive_' in type:
from models.steps.progressive_zoom import create_progressive_zoom_injector
from models.custom_training_components import create_progressive_zoom_injector
return create_progressive_zoom_injector(opt_inject, env)
elif 'stereoscopic_' in type:
from models.steps.stereoscopic import create_stereoscopic_injector
from models.custom_training_components import create_stereoscopic_injector
return create_stereoscopic_injector(opt_inject, env)
elif 'igpt' in type:
from models.archs.transformers.igpt import gpt2
@ -29,8 +27,6 @@ def create_injector(opt_inject, env):
return DiscriminatorInjector(opt_inject, env)
elif type == 'scheduled_scalar':
return ScheduledScalarInjector(opt_inject, env)
elif type == 'img_grad':
return ImageGradientInjector(opt_inject, env)
elif type == 'add_noise':
return AddNoiseInjector(opt_inject, env)
elif type == 'greyscale':
@ -47,14 +43,10 @@ def create_injector(opt_inject, env):
return ForEachInjector(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
elif type == 'fft':
return ImageFftInjector(opt_inject, env)
elif type == 'extract_indices':
return IndicesExtractor(opt_inject, env)
elif type == 'random_shift':
return RandomShiftInjector(opt_inject, env)
elif type == 'psnr':
return PsnrInjector(opt_inject, env)
elif type == 'batch_rotate':
return BatchRotateInjector(opt_inject, env)
elif type == 'sr_diffs':
@ -134,16 +126,6 @@ class DiscriminatorInjector(Injector):
return new_state
# Creates an image gradient from [in] and injects it into [out]
class ImageGradientInjector(Injector):
def __init__(self, opt, env):
super(ImageGradientInjector, self).__init__(opt, env)
self.img_grad_fn = ImageGradientNoPadding().to(env['device'])
def forward(self, state):
return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
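
Injector options are plain dicts keyed by 'type' plus whatever the injector reads off self.opt. For the gradient injector above (as it stood before this cleanup removed it), a config would have looked roughly like this; the state keys are illustrative and this assumes the base Injector only needs 'device' from env:

import torch

inj = create_injector({'type': 'img_grad', 'in': 'gen', 'out': 'gen_grad'}, env={'device': 'cpu'})
state = inj({'gen': torch.randn(1, 3, 32, 32)})
print(state['gen_grad'].shape)  # torch.Size([1, 3, 32, 32])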
# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
# of something over time.
class ScheduledScalarInjector(Injector):
@ -320,37 +302,6 @@ class ConstantInjector(Injector):
return { self.opt['out']: out }
class ImageFftInjector(Injector):
def __init__(self, opt, env):
super(ImageFftInjector, self).__init__(opt, env)
self.is_forward = opt['forward'] # Whether to compute a forward FFT or backward.
self.eps = 1e-100
def forward(self, state):
        if self.is_forward:  # self.forward is the bound method and is always truthy
fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True)
b, f, h, w, c = fftim.shape
fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w)
# Normalize across spatial dimension
mean = torch.mean(fftim, dim=(0,1))
fftim = fftim - mean
std = torch.std(fftim, dim=(0,1))
fftim = (fftim + self.eps) / std
return {self.output: fftim,
'%s_std' % (self.output,): std,
'%s_mean' % (self.output,): mean}
else:
b, f, h, w = state[self.input].shape
# First, de-normalize the FFT.
mean = state['%s_mean' % (self.input,)]
std = state['%s_std' % (self.input,)]
fftim = state[self.input] * std + mean - self.eps
# Second, recover the FFT dimensions from the given filters.
fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2)
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
return {self.output: im}
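
torch.rfft and torch.irfft were removed in PyTorch 1.8; the equivalent calls now live under torch.fft. A sketch of the mapping, keeping the trailing real/imaginary axis this injector expects:

import torch

x = torch.randn(2, 3, 32, 32)
spec = torch.fft.rfft2(x, norm='ortho')        # complex tensor; norm='ortho' matches normalized=True
fftim = torch.view_as_real(spec)               # (..., h, w//2 + 1, 2), the old torch.rfft layout
back = torch.fft.irfft2(torch.view_as_complex(fftim), s=x.shape[-2:], norm='ortho')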
class IndicesExtractor(Injector):
def __init__(self, opt, env):
super(IndicesExtractor, self).__init__(opt, env)
@ -374,21 +325,6 @@ class RandomShiftInjector(Injector):
return {self.output: img}
class PsnrInjector(Injector):
def __init__(self, opt, env):
super(PsnrInjector, self).__init__(opt, env)
self.ssim = SSIM(size_average=False, raw=True)
self.scale = opt['output_scale_divisor']
self.exp = opt['exponent'] if 'exponent' in opt.keys() else 1
def forward(self, state):
img1, img2 = state[self.input[0]], state[self.input[1]]
ssim = self.ssim(img1, img2)
areal_se = torch.nn.functional.interpolate(ssim, scale_factor=1/self.scale,
mode="area")
return {self.output: areal_se}
class BatchRotateInjector(Injector):
def __init__(self, opt, env):
super(BatchRotateInjector, self).__init__(opt, env)
@ -436,7 +372,7 @@ class MultiFrameCombiner(Injector):
self.in_hq_key = opt['in_hq']
self.out_lq_key = opt['out']
self.out_hq_key = opt['out_hq']
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.archs.flownet2.networks import Resample2d
self.resampler = Resample2d()
def combine(self, state):

View File

@ -6,13 +6,12 @@ from models.loss import GANLoss
import random
import functools
import torch.nn.functional as F
import numpy as np
def create_loss(opt_loss, env):
type = opt_loss['type']
if 'teco_' in type:
from models.steps.tecogan_losses import create_teco_loss
from models.custom_training_components.tecogan_losses import create_teco_loss
return create_teco_loss(opt_loss, env)
elif 'stylegan2_' in type:
from models.archs.stylegan import create_stylegan2_loss

View File

@ -10,16 +10,12 @@ import models.archs.stylegan.stylegan2_lucidrains as stylegan2
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.SPSR_arch as spsr
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.feature_arch as feature_arch
import models.archs.rcan as rcan
from models.archs import srg2_classic
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.archs.teco_resgen import TecoGen
from models.archs.tecogan.teco_resgen import TecoGen
from utils.util import opt_get
logger = logging.getLogger('base')
@ -30,11 +26,7 @@ def define_G(opt, opt_net, scale=None):
scale = opt['scale']
which_model = opt_net['which_model_G']
# image restoration
if which_model == 'MSRResNet':
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif 'RRDBNet' in which_model:
if 'RRDBNet' in which_model:
if which_model == 'RRDBNetBypass':
block = RRDBNet_arch.RRDBWithBypass
elif which_model == 'RRDBNetLambda':
@ -50,24 +42,6 @@ def define_G(opt, opt_net, scale=None):
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
initial_stride=initial_stride)
elif which_model == "multires_rrdb":
from models.archs.multi_res_rrdb import MultiResRRDBNet
netG = MultiResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'],
l2_blocks=opt_net['l2'], l3_blocks=opt_net['l3'],
growth_channels=opt_net['gc'], scale=opt_net['scale'])
elif which_model == "twostep_rrdb":
from models.archs.multi_res_rrdb import PixelShufflingSteppedResRRDBNet
netG = PixelShufflingSteppedResRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], l1_blocks=opt_net['l1'],
l2_blocks=opt_net['l2'],
growth_channels=opt_net['gc'], scale=opt_net['scale'])
elif which_model == 'rcan':
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
opt_net['rgb_range'] = 255
opt_net['n_colors'] = 3
args_obj = munchify(opt_net)
netG = rcan.RCAN(args_obj)
elif which_model == "ConfigurableSwitchedResidualGenerator2":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
@ -87,22 +61,8 @@ def define_G(opt, opt_net, scale=None):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == 'spsr':
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'spsr_net_improved':
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == "spsr_switched":
netG = spsr.SwitchedSpsr(in_nc=3, nf=opt_net['nf'], upscale=opt_net['scale'], init_temperature=opt_net['temperature'])
elif which_model == "spsr7":
recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
elif which_model == "flownet2":
from models.flownet2.models import FlowNet2
from models.archs.flownet2 import FlowNet2
ld = 'load_path' in opt_net.keys()
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
netG = FlowNet2(args)
@ -148,12 +108,12 @@ def define_G(opt, opt_net, scale=None):
from models.archs.transformers.igpt.gpt2 import iGPT2
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
elif which_model == 'byol':
from models.byol.byol_model_wrapper import BYOL
from models.archs.byol.byol_model_wrapper import BYOL
subnet = define_G(opt, opt_net['subnet'])
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
elif which_model == 'structural_byol':
from models.byol.byol_structural import StructuralBYOL
from models.archs.byol.byol_structural import StructuralBYOL
subnet = define_G(opt, opt_net['subnet'])
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
@ -206,8 +166,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
#state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
#netD.load_state_dict(state_dict, strict=False)
netD.fc = torch.nn.Linear(512 * 4, 1)
elif which_model == 'biggan_resnet':
netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2))
elif which_model == 'discriminator_pix':
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet":

View File

@ -1,12 +1,12 @@
from torch.cuda.amp import GradScaler, autocast
from torch.cuda.amp import GradScaler
from utils.loss_accumulator import LossAccumulator
from torch.nn import Module
import logging
from models.steps.losses import create_loss
from models.losses import create_loss
import torch
from collections import OrderedDict
from .injectors import create_injector
from models.injectors import create_injector
from utils.util import recursively_detach
logger = logging.getLogger('base')

View File

@ -8,14 +8,11 @@ import os
import utils
from models.ExtensibleTrainer import ExtensibleTrainer
from models.networks import define_F
from models.steps.losses import FeatureLoss
from utils import options as option
import utils.util as util
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import torchvision
if __name__ == "__main__":
bin_path = "f:\\binned"

View File

@ -1,22 +0,0 @@
init_temperature = 10
final_temperature_step = 50
heightened_final_step = 90
heightened_temp_min = .1
for step in range(100):
temp = max(1, 1 + init_temperature * (final_temperature_step - step) / final_temperature_step)
if temp == 1 and step > final_temperature_step and heightened_final_step and heightened_final_step != 1:
        # Once the temperature decays to 1, it enters an inverted curve to match the linear curve from above.
        # Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
h_steps_total = heightened_final_step - final_temperature_step
h_steps_current = min(step - final_temperature_step, h_steps_total)
# The "gap" will represent the steps that need to be traveled as a linear function.
h_gap = 1 / heightened_temp_min
temp = h_gap * h_steps_current / h_steps_total
# Invert temperature to represent reality on this side of the curve
temp = 1 / temp
print("%i: %f" % (step, temp))