forked from mrq/DL-Art-School
Add support for passthrough disc/gen
Add RRDBNetXL, which performs processing at multiple image sizes. Add DiscResnet_passthrough, which allows passthrough of images at different sizes for discrimination. Adjust the rest of the repo to support generators that return more than just a single image.
parent 44b89330c2
commit 3b4e54c4c5
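At a high level, the change lets a generator hand the discriminator a tuple of outputs at several resolutions instead of a single image. The sketch below is an illustrative summary only and is not part of the diff; the toy modules, tensor sizes, and variable names are assumptions chosen to mirror the new convention of (full-res output, half-res skip, quarter-res skip).

import torch
import torch.nn.functional as F

# Stand-in networks, for illustration only.
class ToyGen(torch.nn.Module):
    def forward(self, x):
        hi = F.interpolate(x, scale_factor=4)
        # Convention introduced by this commit: element 0 is always the master output.
        return hi, F.interpolate(hi, scale_factor=.5), F.interpolate(hi, scale_factor=.25)

class ToyPassthroughDisc(torch.nn.Module):
    def forward(self, x):
        hi, med_skip, lo_skip = x  # a passthrough disc consumes the whole tuple
        return hi.mean(dim=(1, 2, 3))

netG, netD = ToyGen(), ToyPassthroughDisc()
lr = torch.rand(2, 3, 16, 16)
hr = torch.rand(2, 3, 64, 64)

fake = netG(lr)
fake_img = fake[0] if isinstance(fake, tuple) else fake  # what pixel/feature losses see

# Real side: the HR reference is expanded into the same tuple layout by downsampling.
real = (hr, F.interpolate(hr, scale_factor=.5), F.interpolate(hr, scale_factor=.25))
pred_fake, pred_real = netD(fake), netD(real)
print(fake_img.shape, pred_fake.shape, pred_real.shape)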
@@ -8,6 +8,7 @@ import models.lr_scheduler as lr_scheduler
 from models.base_model import BaseModel
 from models.loss import GANLoss
 from apex import amp
+import torch.nn.functional as F
 
 import torchvision.utils as utils
 import os
@@ -156,10 +157,21 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
 
-        self.fake_H = []
+        self.fake_GenOut = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
-            fake_H = self.netG(var_L)
-            self.fake_H.append(fake_H.detach())
+            fake_GenOut = self.netG(var_L)
+
+            # Extract the image output. For generators that output skip-through connections, the master output is always
+            # the first element of the tuple.
+            if isinstance(fake_GenOut, tuple):
+                fake_H = fake_GenOut[0]
+                # TODO: Fix this.
+                self.fake_GenOut.append((fake_GenOut[0].detach(),
+                                         fake_GenOut[1].detach(),
+                                         fake_GenOut[2].detach()))
+            else:
+                fake_H = fake_GenOut
+                self.fake_GenOut.append(fake_GenOut.detach())
 
             l_g_total = 0
             if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -178,11 +190,11 @@ class SRGANModel(BaseModel):
                     self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
                 if self.opt['train']['gan_type'] == 'gan':
-                    pred_g_fake = self.netD(fake_H)
+                    pred_g_fake = self.netD(fake_GenOut)
                     l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                 elif self.opt['train']['gan_type'] == 'ragan':
                     pred_d_real = self.netD(var_ref).detach()
-                    pred_g_fake = self.netD(fake_H)
+                    pred_g_fake = self.netD(fake_GenOut)
                     l_g_gan = self.l_gan_w * (
                         self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                         self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@@ -196,8 +208,17 @@ class SRGANModel(BaseModel):
         for p in self.netD.parameters():
             p.requires_grad = True
 
+        # Convert var_ref to have the same output format as the generator. This generally means interpolating the
+        # HR images to have the same output dimensions as each generator skip connection.
+        if isinstance(self.fake_GenOut[0], tuple):
+            var_ref_skips = []
+            for ref, hi_res in zip(self.var_ref, self.var_H):
+                var_ref_skips.append((ref,) + self.create_artificial_skips(hi_res))
+        else:
+            var_ref_skips = self.var_ref
+
         self.optimizer_D.zero_grad()
-        for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H):
+        for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
             if self.opt['train']['gan_type'] == 'gan':
                 # need to forward and backward separately, since batch norm statistics differ
                 # real
@@ -206,7 +227,7 @@ class SRGANModel(BaseModel):
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 # fake
-                pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
+                pred_d_fake = self.netD(fake_H)  # detach to avoid BP to G
                 l_d_fake = self.cri_gan(pred_d_fake, False)
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
@@ -217,12 +238,12 @@ class SRGANModel(BaseModel):
                 # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
                 # l_d_total = (l_d_real + l_d_fake) / 2
                 # l_d_total.backward()
-                pred_d_fake = self.netD(fake_H.detach()).detach()
+                pred_d_fake = self.netD(fake_H).detach()
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
-                pred_d_fake = self.netD(fake_H.detach())
+                pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
|
@ -234,11 +255,14 @@ class SRGANModel(BaseModel):
|
|||
os.makedirs("temp/lr", exist_ok=True)
|
||||
os.makedirs("temp/gen", exist_ok=True)
|
||||
os.makedirs("temp/pix", exist_ok=True)
|
||||
gen_batch = self.fake_GenOut[0]
|
||||
if isinstance(gen_batch, tuple):
|
||||
gen_batch = gen_batch[0]
|
||||
for i in range(self.var_L[0].shape[0]):
|
||||
utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(gen_batch[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
# set log TODO(handle mega-batches?)
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
|
@@ -253,10 +277,15 @@ class SRGANModel(BaseModel):
             self.log_dict['l_d_fake'] = l_d_fake.item()
             self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
 
+    def create_artificial_skips(self, truth_img):
+        med_skip = F.interpolate(truth_img, scale_factor=.5)
+        lo_skip = F.interpolate(truth_img, scale_factor=.25)
+        return med_skip, lo_skip
+
     def test(self):
         self.netG.eval()
         with torch.no_grad():
-            self.fake_H = [self.netG(self.var_L[0])]
+            self.fake_GenOut = [self.netG(self.var_L[0])]
         self.netG.train()
 
     def get_current_log(self):
@@ -265,7 +294,10 @@ class SRGANModel(BaseModel):
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
         out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
-        out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu()
+        gen_batch = self.fake_GenOut[0]
+        if isinstance(gen_batch, tuple):
+            gen_batch = gen_batch[0]
+        out_dict['rlt'] = gen_batch.detach()[0].float().cpu()
         if need_GT:
             out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
         return out_dict
@@ -1,161 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.lrelu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)
            identity = torch.cat((identity, torch.zeros_like(identity)), 1)

        out += identity
        out = self.lrelu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_filters=16, num_classes=10):
        super(ResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = num_filters
        self.conv1 = conv3x3(3, num_filters)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layer1 = self._make_layer(block, num_filters, layers[0])
        self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2)
        self.skip_conv1 = conv3x3(3, num_filters*2)
        self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
        self.skip_conv2 = conv3x3(3, num_filters*4)
        self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2)
        self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True)
        self.fc2 = nn.Linear(64, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.BatchNorm2d(self.inplanes),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, gen_skips=None):
        x_dim = x.size(-1)
        if gen_skips is None:
            gen_skips = {
                int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
                int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
            }
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.lrelu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = (x + self.skip_conv1(gen_skips[int(x_dim/2)])) / 2
        x = self.layer3(x)
        x = (x + self.skip_conv2(gen_skips[int(x_dim/4)])) / 2
        x = self.layer4(x)

        x = x.view(x.size(0), -1)
        x = self.lrelu(self.fc1(x))
        x = self.fc2(x)

        return x


def resnet20(**kwargs):
    """Constructs a ResNet-20 model.
    """
    model = ResNet(BasicBlock, [3, 3, 3], **kwargs)
    return model


def resnet32(**kwargs):
    """Constructs a ResNet-32 model.
    """
    model = ResNet(BasicBlock, [5, 5, 5], **kwargs)
    return model


def resnet44(**kwargs):
    """Constructs a ResNet-44 model.
    """
    model = ResNet(BasicBlock, [7, 7, 7], **kwargs)
    return model


def resnet56(**kwargs):
    """Constructs a ResNet-56 model.
    """
    model = ResNet(BasicBlock, [9, 9, 9], **kwargs)
    return model


def resnet110(**kwargs):
    """Constructs a ResNet-110 model.
    """
    model = ResNet(BasicBlock, [18, 18, 18], **kwargs)
    return model


def resnet1202(**kwargs):
    """Constructs a ResNet-1202 model.
    """
    model = ResNet(BasicBlock, [200, 200, 200], **kwargs)
    return model
codes/models/archs/DiscriminatorResnet_arch_passthrough.py (new file, 207 lines)
@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
import numpy as np


__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class FixupBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(FixupBasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        out = self.lrelu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        out = out * self.scale + self.bias2b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)

        out += identity
        out = self.lrelu(out)

        return out


class FixupBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(FixupBottleneck, self).__init__()
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv1x1(inplanes, planes)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes, stride)
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.bias3a = nn.Parameter(torch.zeros(1))
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias3b = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        out = self.lrelu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        out = self.lrelu(out + self.bias2b)

        out = self.conv3(out + self.bias3a)
        out = out * self.scale + self.bias3b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)

        out += identity
        out = self.lrelu(out)

        return out


class FixupResNet(nn.Module):

    def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64):
        super(FixupResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = num_filters
        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
        self.skip1 = nn.Conv2d(num_filters + 3, num_filters, kernel_size=5, stride=1, padding=2, bias=False)
        self.skip1_bias = nn.Parameter(torch.zeros(1))
        self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2)
        self.skip2 = nn.Conv2d(num_filters*2 + 3, num_filters*2, kernel_size=5, stride=1, padding=2, bias=False)
        self.skip2_bias = nn.Parameter(torch.zeros(1))
        self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
        self.layer5 = self._make_layer(block, num_filters*16, layers[4], stride=2)
        self.bias2 = nn.Parameter(torch.zeros(1))
        reduced_img_sz = int(input_img_size / 32)
        self.fc1 = nn.Linear(num_filters * 16 * reduced_img_sz * reduced_img_sz, 100)
        self.fc2 = nn.Linear(100, num_classes)

        for m in self.modules():
            if isinstance(m, FixupBasicBlock):
                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
                nn.init.constant_(m.conv2.weight, 0)
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
            elif isinstance(m, FixupBottleneck):
                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
                nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
                nn.init.constant_(m.conv3.weight, 0)
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
            '''
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)'''

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = conv1x1(self.inplanes, planes * block.expansion, stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
        hi, med_skip, lo_skip = x

        x = self.conv1(hi)
        x = self.lrelu(x + self.bias1)
        x = self.layer1(x)
        x = self.lrelu(self.skip1(torch.cat([x, med_skip], dim=1)) + self.skip1_bias)

        x = self.layer2(x)
        x = self.lrelu(self.skip2(torch.cat([x, lo_skip], dim=1)) + self.skip2_bias)

        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)

        x = x.view(x.size(0), -1)
        x = self.lrelu(self.fc1(x))
        x = self.fc2(x + self.bias2)

        return x


def fixup_resnet18(**kwargs):
    """Constructs a Fixup-ResNet-18 model.
    """
    model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2, 2], **kwargs)
    return model


def fixup_resnet34(**kwargs):
    """Constructs a Fixup-ResNet-34 model.
    """
    model = FixupResNet(FixupBasicBlock, [5, 4, 3, 3, 2], **kwargs)
    return model


def fixup_resnet50(**kwargs):
    """Constructs a Fixup-ResNet-50 model.
    """
    model = FixupResNet(FixupBottleneck, [3, 4, 6, 3, 2], **kwargs)
    return model


def fixup_resnet101(**kwargs):
    """Constructs a Fixup-ResNet-101 model.
    """
    model = FixupResNet(FixupBottleneck, [3, 4, 23, 3, 2], **kwargs)
    return model


def fixup_resnet152(**kwargs):
    """Constructs a Fixup-ResNet-152 model.
    """
    model = FixupResNet(FixupBottleneck, [3, 8, 36, 3, 2], **kwargs)
    return model


__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
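As a quick sanity check of the new discriminator's passthrough interface, a hypothetical usage sketch (not part of the diff); it assumes the repository's codes/ directory is on the import path, and borrows nf: 42 and the 256px target size from the XL config further down:

import torch
from models.archs.DiscriminatorResnet_arch_passthrough import fixup_resnet34

netD = fixup_resnet34(num_filters=42, num_classes=1, input_img_size=256)
hi = torch.rand(1, 3, 256, 256)   # full-res image under judgement
med = torch.rand(1, 3, 128, 128)  # half-res skip
lo = torch.rand(1, 3, 64, 64)     # quarter-res skip
print(netD((hi, med, lo)).shape)  # -> torch.Size([1, 1])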
codes/models/archs/RRDBNetXL_arch.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
                                     0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb_lo, nb_med, nb_hi, gc=32, interpolation_scale_factor=2):
        super(RRDBNet, self).__init__()
        nfmed = int(nf/2)
        nfhi = int(nf/8)
        gcmed = int(gc/2)
        gchi = int(gc/8)
        RRDB_block_f_lo = functools.partial(RRDB, nf=nf, gc=gc)
        RRDB_block_f_lo_med = functools.partial(RRDB, nf=nfmed, gc=gcmed)
        RRDB_block_f_lo_hi = functools.partial(RRDB, nf=nfhi, gc=gchi)

        self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
        self.RRDB_trunk_lo = arch_util.make_layer(RRDB_block_f_lo, nb_lo)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.lo_skip_conv1 = nn.Conv2d(nf, nf, 3, 1, padding=1, bias=True)
        self.lo_skip_conv2 = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nfmed, 3, 1, padding=1, bias=True)
        self.RRDB_trunk_med = arch_util.make_layer(RRDB_block_f_lo_med, nb_med)
        self.trunk_conv_med = nn.Conv2d(nfmed, nfmed, 3, 1, 1, bias=True)
        self.med_skip_conv1 = nn.Conv2d(nfmed, nfmed, 3, 1, padding=1, bias=True)
        self.med_skip_conv2 = nn.Conv2d(nfmed, out_nc, 3, 1, 1, bias=True)

        self.upconv2 = nn.Conv2d(nfmed, nfhi, 3, 1, padding=1, bias=True)
        self.RRDB_trunk_hi = arch_util.make_layer(RRDB_block_f_lo_hi, nb_hi)
        self.trunk_conv_hi = nn.Conv2d(nfhi, nfhi, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nfhi, nfhi, 5, 1, padding=2, bias=True)
        self.conv_last = nn.Conv2d(nfhi, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.interpolation_scale_factor = interpolation_scale_factor

    def forward(self, x):
        fea = self.conv_first(x)
        branch = self.trunk_conv(self.RRDB_trunk_lo(fea))
        fea = (fea + branch) / 2
        lo_skip = self.lo_skip_conv2(self.lrelu(self.lo_skip_conv1(fea)))

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest')))
        branch = self.trunk_conv_med(self.RRDB_trunk_med(fea))
        fea = (fea + branch) / 2
        med_skip = self.med_skip_conv2(self.lrelu(self.med_skip_conv1(fea)))

        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest')))
        branch = self.trunk_conv_hi(self.RRDB_trunk_hi(fea))
        fea = (fea + branch) / 2
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out, med_skip, lo_skip
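And a matching hypothetical shape check for the generator (not part of the diff); the nf/nblo/nbmed/nbhi values come from the XL config further down, and interpolation_scale_factor=2 corresponds to scale: 4 with scale_per_step = sqrt(scale) in define_G (next hunk):

import torch
from models.archs.RRDBNetXL_arch import RRDBNet

netG = RRDBNet(in_nc=3, out_nc=3, nf=64, nb_lo=18, nb_med=8, nb_hi=6, interpolation_scale_factor=2)
lr = torch.rand(1, 3, 64, 64)
out, med_skip, lo_skip = netG(lr)
print(out.shape, med_skip.shape, lo_skip.shape)
# -> torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 64, 64])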
@@ -2,9 +2,10 @@ import torch
 import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
-import models.archs.DiscriminatorResnetBN_arch as DiscriminatorResnetBN_arch
+import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
 import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
+import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
 #import models.archs.EDVR_arch as EDVR_arch
 import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
@@ -26,6 +27,11 @@ def define_G(opt):
         scale_per_step = math.sqrt(scale)
         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
+    elif which_model == 'RRDBNetXL':
+        scale_per_step = math.sqrt(scale)
+        netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                      nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
+                                      interpolation_scale_factor=scale_per_step)
     # image corruption
     elif which_model == 'HighToLowResNet':
         netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
@@ -59,6 +65,8 @@ def define_D(opt):
         netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
     elif which_model == 'discriminator_resnet':
         netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
+    elif which_model == 'discriminator_resnet_passthrough':
+        netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
@@ -16,7 +16,7 @@ datasets:
     dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
     doCrop: false
     use_shuffle: true
-    n_workers: 12  # per GPU
+    n_workers: 0  # per GPU
     batch_size: 40
     target_size: 256
     color: RGB
@@ -43,7 +43,7 @@ path:
   pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
   pretrain_model_D: ~
   strict_load: true
-  resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state
+  resume_state: ../experiments/blacked_fix_and_upconv/training_state/16500.state
 
 #### training settings: learning rate scheme, loss
 train:
codes/options/train/train_ESRGAN_blacked_xl.yml (new file, 87 lines)
@@ -0,0 +1,87 @@
#### general settings
name: blacked_fix_and_upconv_xl
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1

#### datasets
datasets:
  train:
    name: vixcloseup
    mode: LQGT
    dataroot_GT: K:\4k6k\4k_closeup\hr
    dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
    doCrop: false
    use_shuffle: true
    n_workers: 8  # per GPU
    batch_size: 6
    target_size: 256
    color: RGB
  val:
    name: adrianna_val
    mode: LQGT
    dataroot_GT: E:\4k6k\datasets\adrianna\val\hhq
    dataroot_LQ: E:\4k6k\datasets\adrianna\val\hr

#### network structures
network_G:
  which_model_G: RRDBNetXL
  in_nc: 3
  out_nc: 3
  nf: 64
  nblo: 18
  nbmed: 8
  nbhi: 6
network_D:
  which_model_D: discriminator_resnet_passthrough
  in_nc: 3
  nf: 42

#### path
path:
  pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
  pretrain_model_D: ~
  strict_load: true
  resume_state: ~

#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 1e-4
  weight_decay_G: 0
  beta1_G: 0.9
  beta2_G: 0.99
  lr_D: !!float 1e-4
  weight_decay_D: 0
  beta1_D: 0.9
  beta2_D: 0.99
  lr_scheme: MultiStepLR

  niter: 400000
  warmup_iter: -1  # no warm up
  lr_steps: [20000, 40000, 50000, 60000]
  lr_gamma: 0.5
  mega_batch_factor: 1

  pixel_criterion: l1
  pixel_weight: !!float 1e-2
  feature_criterion: l1
  feature_weight: 1
  feature_weight_decay: 1
  feature_weight_decay_steps: 500
  feature_weight_minimum: 1
  gan_type: gan  # gan | ragan
  gan_weight: !!float 1e-2

  D_update_ratio: 1
  D_init_iters: -1

  manual_seed: 10
  val_freq: !!float 5e2

#### logger
logger:
  print_freq: 50
  save_checkpoint_freq: !!float 5e2
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
|
Loading…
Reference in New Issue
Block a user