Add support for passthrough disc/gen

Add RRDBNetXL, which performs processing at multiple image sizes.
Add DiscResnet_passthrough, which allows passthrough of images at different sizes for discrimination.
Adjust the rest of the repo to allow generators that return more than just a single image (the new contract is sketched below the file summary).
This commit is contained in:
James Betker 2020-05-04 14:01:43 -06:00
parent 44b89330c2
commit 3b4e54c4c5
8 changed files with 448 additions and 177 deletions
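For orientation, a minimal sketch of the contract the training loop now assumes (not code from this commit; netG, netD, lr, hr and the discriminate helper are stand-in names): generators may return either a single image tensor or a tuple whose first element is the full-resolution output followed by lower-resolution skip outputs, and the passthrough discriminator consumes the same tuple layout, with the "real" side built by downsampling the HR reference.

import torch
import torch.nn.functional as F

def discriminate(netD, netG, lr, hr):
    gen_out = netG(lr)
    if isinstance(gen_out, tuple):
        # Tuple generators: (full-res output, half-res skip, quarter-res skip).
        fake = tuple(t.detach() for t in gen_out)
        # Build a matching "real" tuple by downsampling the HR reference
        # (nearest-neighbour, the F.interpolate default).
        real = (hr,
                F.interpolate(hr, scale_factor=0.5),
                F.interpolate(hr, scale_factor=0.25))
    else:
        fake, real = gen_out.detach(), hr
    return netD(fake), netD(real)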

View File

@@ -8,6 +8,7 @@ import models.lr_scheduler as lr_scheduler
from models.base_model import BaseModel
from models.loss import GANLoss
from apex import amp
import torch.nn.functional as F
import torchvision.utils as utils
import os
@@ -156,10 +157,21 @@ class SRGANModel(BaseModel):
if step > self.D_init_iters:
self.optimizer_G.zero_grad()
self.fake_H = []
self.fake_GenOut = []
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
fake_H = self.netG(var_L)
self.fake_H.append(fake_H.detach())
fake_GenOut = self.netG(var_L)
# Extract the image output. For generators that output skip-through connections, the master output is always
# the first element of the tuple.
if isinstance(fake_GenOut, tuple):
fake_H = fake_GenOut[0]
# TODO: Fix this.
self.fake_GenOut.append((fake_GenOut[0].detach(),
fake_GenOut[1].detach(),
fake_GenOut[2].detach()))
else:
fake_H = fake_GenOut
self.fake_GenOut.append(fake_GenOut.detach())
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -178,11 +190,11 @@ class SRGANModel(BaseModel):
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.opt['train']['gan_type'] == 'gan':
pred_g_fake = self.netD(fake_H)
pred_g_fake = self.netD(fake_GenOut)
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(var_ref).detach()
pred_g_fake = self.netD(fake_H)
pred_g_fake = self.netD(fake_GenOut)
l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@@ -196,8 +208,17 @@ class SRGANModel(BaseModel):
for p in self.netD.parameters():
p.requires_grad = True
# Convert var_ref to have the same output format as the generator. This generally means interpolating the
# HR images to have the same output dimensions as each generator skip connection.
if isinstance(self.fake_GenOut[0], tuple):
var_ref_skips = []
for ref, hi_res in zip(self.var_ref, self.var_H):
var_ref_skips.append((ref,) + self.create_artificial_skips(hi_res))
else:
var_ref_skips = self.var_ref
self.optimizer_D.zero_grad()
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H):
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
if self.opt['train']['gan_type'] == 'gan':
# need to forward and backward separately, since batch norm statistics differ
# real
@@ -206,7 +227,7 @@ class SRGANModel(BaseModel):
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
# fake
pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
pred_d_fake = self.netD(fake_H) # fake_H was detached when stored in fake_GenOut, so no BP to G
l_d_fake = self.cri_gan(pred_d_fake, False)
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
@@ -217,12 +238,12 @@ class SRGANModel(BaseModel):
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
# l_d_total = (l_d_real + l_d_fake) / 2
# l_d_total.backward()
pred_d_fake = self.netD(fake_H.detach()).detach()
pred_d_fake = self.netD(fake_H).detach()
pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
pred_d_fake = self.netD(fake_H.detach())
pred_d_fake = self.netD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
@@ -234,11 +255,14 @@ class SRGANModel(BaseModel):
os.makedirs("temp/lr", exist_ok=True)
os.makedirs("temp/gen", exist_ok=True)
os.makedirs("temp/pix", exist_ok=True)
gen_batch = self.fake_GenOut[0]
if isinstance(gen_batch, tuple):
gen_batch = gen_batch[0]
for i in range(self.var_L[0].shape[0]):
utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
utils.save_image(gen_batch[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
# set log TODO(handle mega-batches?)
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -253,10 +277,15 @@ class SRGANModel(BaseModel):
self.log_dict['l_d_fake'] = l_d_fake.item()
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
def create_artificial_skips(self, truth_img):
med_skip = F.interpolate(truth_img, scale_factor=.5)
lo_skip = F.interpolate(truth_img, scale_factor=.25)
return med_skip, lo_skip
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = [self.netG(self.var_L[0])]
self.fake_GenOut = [self.netG(self.var_L[0])]
self.netG.train()
def get_current_log(self):
@@ -265,7 +294,10 @@ class SRGANModel(BaseModel):
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu()
gen_batch = self.fake_GenOut[0]
if isinstance(gen_batch, tuple):
gen_batch = gen_batch[0]
out_dict['rlt'] = gen_batch.detach()[0].float().cpu()
if need_GT:
out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
return out_dict
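To make the size bookkeeping in create_artificial_skips concrete, a quick shape check (a sketch only, assuming the 256-pixel target_size used in the configs below and the nearest-neighbour default of F.interpolate): the artificial skips are plain downsamples of the ground truth, matching the generator's half- and quarter-resolution outputs.

import torch
import torch.nn.functional as F

hr = torch.randn(6, 3, 256, 256)                 # batch of HR reference crops
med_skip = F.interpolate(hr, scale_factor=.5)    # -> (6, 3, 128, 128)
lo_skip = F.interpolate(hr, scale_factor=.25)    # -> (6, 3, 64, 64)
var_ref_skip = (hr, med_skip, lo_skip)           # same layout as the generator tuple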

View File

@@ -1,161 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.lrelu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
identity = torch.cat((identity, torch.zeros_like(identity)), 1)
out += identity
out = self.lrelu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_filters=16, num_classes=10):
super(ResNet, self).__init__()
self.num_layers = sum(layers)
self.inplanes = num_filters
self.conv1 = conv3x3(3, num_filters)
self.bn1 = nn.BatchNorm2d(num_filters)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer1 = self._make_layer(block, num_filters, layers[0])
self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2)
self.skip_conv1 = conv3x3(3, num_filters*2)
self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
self.skip_conv2 = conv3x3(3, num_filters*4)
self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2)
self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True)
self.fc2 = nn.Linear(64, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
for m in self.modules():
if isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1:
downsample = nn.Sequential(
nn.AvgPool2d(1, stride=stride),
nn.BatchNorm2d(self.inplanes),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(block(planes, planes))
return nn.Sequential(*layers)
def forward(self, x, gen_skips=None):
x_dim = x.size(-1)
if gen_skips is None:
gen_skips = {
int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
}
x = self.conv1(x)
x = self.bn1(x)
x = self.lrelu(x)
x = self.layer1(x)
x = self.layer2(x)
x = (x + self.skip_conv1(gen_skips[int(x_dim/2)])) / 2
x = self.layer3(x)
x = (x + self.skip_conv2(gen_skips[int(x_dim/4)])) / 2
x = self.layer4(x)
x = x.view(x.size(0), -1)
x = self.lrelu(self.fc1(x))
x = self.fc2(x)
return x
def resnet20(**kwargs):
"""Constructs a ResNet-20 model.
"""
model = ResNet(BasicBlock, [3, 3, 3], **kwargs)
return model
def resnet32(**kwargs):
"""Constructs a ResNet-32 model.
"""
model = ResNet(BasicBlock, [5, 5, 5], **kwargs)
return model
def resnet44(**kwargs):
"""Constructs a ResNet-44 model.
"""
model = ResNet(BasicBlock, [7, 7, 7], **kwargs)
return model
def resnet56(**kwargs):
"""Constructs a ResNet-56 model.
"""
model = ResNet(BasicBlock, [9, 9, 9], **kwargs)
return model
def resnet110(**kwargs):
"""Constructs a ResNet-110 model.
"""
model = ResNet(BasicBlock, [18, 18, 18], **kwargs)
return model
def resnet1202(**kwargs):
"""Constructs a ResNet-1202 model.
"""
model = ResNet(BasicBlock, [200, 200, 200], **kwargs)
return model

View File

@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
import numpy as np
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class FixupBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(FixupBasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.bias1a = nn.Parameter(torch.zeros(1))
self.conv1 = conv3x3(inplanes, planes, stride)
self.bias1b = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bias2a = nn.Parameter(torch.zeros(1))
self.conv2 = conv3x3(planes, planes)
self.scale = nn.Parameter(torch.ones(1))
self.bias2b = nn.Parameter(torch.zeros(1))
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
out = self.lrelu(out + self.bias1b)
out = self.conv2(out + self.bias2a)
out = out * self.scale + self.bias2b
if self.downsample is not None:
identity = self.downsample(x + self.bias1a)
out += identity
out = self.lrelu(out)
return out
class FixupBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(FixupBottleneck, self).__init__()
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.bias1a = nn.Parameter(torch.zeros(1))
self.conv1 = conv1x1(inplanes, planes)
self.bias1b = nn.Parameter(torch.zeros(1))
self.bias2a = nn.Parameter(torch.zeros(1))
self.conv2 = conv3x3(planes, planes, stride)
self.bias2b = nn.Parameter(torch.zeros(1))
self.bias3a = nn.Parameter(torch.zeros(1))
self.conv3 = conv1x1(planes, planes * self.expansion)
self.scale = nn.Parameter(torch.ones(1))
self.bias3b = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
out = self.lrelu(out + self.bias1b)
out = self.conv2(out + self.bias2a)
out = self.lrelu(out + self.bias2b)
out = self.conv3(out + self.bias3a)
out = out * self.scale + self.bias3b
if self.downsample is not None:
identity = self.downsample(x + self.bias1a)
out += identity
out = self.lrelu(out)
return out
class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64):
super(FixupResNet, self).__init__()
self.num_layers = sum(layers)
self.inplanes = num_filters
self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3,
bias=False)
self.bias1 = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
self.skip1 = nn.Conv2d(num_filters + 3, num_filters, kernel_size=5, stride=1, padding=2, bias=False)
self.skip1_bias = nn.Parameter(torch.zeros(1))
self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2)
self.skip2 = nn.Conv2d(num_filters*2 + 3, num_filters*2, kernel_size=5, stride=1, padding=2, bias=False)
self.skip2_bias = nn.Parameter(torch.zeros(1))
self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.layer5 = self._make_layer(block, num_filters*16, layers[4], stride=2)
self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32)
self.fc1 = nn.Linear(num_filters * 16 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = nn.Linear(100, num_classes)
for m in self.modules():
if isinstance(m, FixupBasicBlock):
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
nn.init.constant_(m.conv2.weight, 0)
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
elif isinstance(m, FixupBottleneck):
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
nn.init.constant_(m.conv3.weight, 0)
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
'''
elif isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)'''
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
# This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
hi, med_skip, lo_skip = x
x = self.conv1(hi)
x = self.lrelu(x + self.bias1)
x = self.layer1(x)
x = self.lrelu(self.skip1(torch.cat([x, med_skip], dim=1)) + self.skip1_bias)
x = self.layer2(x)
x = self.lrelu(self.skip2(torch.cat([x, lo_skip], dim=1)) + self.skip2_bias)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = x.view(x.size(0), -1)
x = self.lrelu(self.fc1(x))
x = self.fc2(x + self.bias2)
return x
def fixup_resnet18(**kwargs):
"""Constructs a Fixup-ResNet-18 model.2
"""
model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2, 2], **kwargs)
return model
def fixup_resnet34(**kwargs):
"""Constructs a Fixup-ResNet-34 model.
"""
model = FixupResNet(FixupBasicBlock, [5, 4, 3, 3, 2], **kwargs)
return model
def fixup_resnet50(**kwargs):
"""Constructs a Fixup-ResNet-50 model.
"""
model = FixupResNet(FixupBottleneck, [3, 4, 6, 3, 2], **kwargs)
return model
def fixup_resnet101(**kwargs):
"""Constructs a Fixup-ResNet-101 model.
"""
model = FixupResNet(FixupBottleneck, [3, 4, 23, 3, 2], **kwargs)
return model
def fixup_resnet152(**kwargs):
"""Constructs a Fixup-ResNet-152 model.
"""
model = FixupResNet(FixupBottleneck, [3, 8, 36, 3, 2], **kwargs)
return model
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
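A quick smoke test of the passthrough discriminator's expected input layout (a sketch only: nf=42 is taken from the config later in this commit, and input_img_size is assumed to resolve to the 256-pixel training crop; the import path follows networks.py).

import torch
import models.archs.DiscriminatorResnet_arch_passthrough as disc_arch

netD = disc_arch.fixup_resnet34(num_filters=42, num_classes=1, input_img_size=256)
fake_or_real = (torch.randn(1, 3, 256, 256),   # full-resolution image
                torch.randn(1, 3, 128, 128),   # half-resolution skip
                torch.randn(1, 3, 64, 64))     # quarter-resolution skip
print(netD(fake_or_real).shape)                # torch.Size([1, 1])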

View File

@@ -0,0 +1,98 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb_lo, nb_med, nb_hi, gc=32, interpolation_scale_factor=2):
super(RRDBNet, self).__init__()
nfmed = int(nf/2)
nfhi = int(nf/8)
gcmed = int(gc/2)
gchi = int(gc/8)
RRDB_block_f_lo = functools.partial(RRDB, nf=nf, gc=gc)
RRDB_block_f_lo_med = functools.partial(RRDB, nf=nfmed, gc=gcmed)
RRDB_block_f_lo_hi = functools.partial(RRDB, nf=nfhi, gc=gchi)
self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
self.RRDB_trunk_lo = arch_util.make_layer(RRDB_block_f_lo, nb_lo)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.lo_skip_conv1 = nn.Conv2d(nf, nf, 3, 1, padding=1, bias=True)
self.lo_skip_conv2 = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nfmed, 3, 1, padding=1, bias=True)
self.RRDB_trunk_med = arch_util.make_layer(RRDB_block_f_lo_med, nb_med)
self.trunk_conv_med = nn.Conv2d(nfmed, nfmed, 3, 1, 1, bias=True)
self.med_skip_conv1 = nn.Conv2d(nfmed, nfmed, 3, 1, padding=1, bias=True)
self.med_skip_conv2 = nn.Conv2d(nfmed, out_nc, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nfmed, nfhi, 3, 1, padding=1, bias=True)
self.RRDB_trunk_hi = arch_util.make_layer(RRDB_block_f_lo_hi, nb_hi)
self.trunk_conv_hi = nn.Conv2d(nfhi, nfhi, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nfhi, nfhi, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(nfhi, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.interpolation_scale_factor = interpolation_scale_factor
def forward(self, x):
fea = self.conv_first(x)
branch = self.trunk_conv(self.RRDB_trunk_lo(fea))
fea = (fea + branch) / 2
lo_skip = self.lo_skip_conv2(self.lrelu(self.lo_skip_conv1(fea)))
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest')))
branch = self.trunk_conv_med(self.RRDB_trunk_med(fea))
fea = (fea + branch) / 2
med_skip = self.med_skip_conv2(self.lrelu(self.med_skip_conv1(fea)))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest')))
branch = self.trunk_conv_hi(self.RRDB_trunk_hi(fea))
fea = (fea + branch) / 2
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out, med_skip, lo_skip
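The generator side of the same contract (a sketch, using the nf/nb settings from the XL config below; with scale 4, define_G passes interpolation_scale_factor = sqrt(4) = 2 per upsampling step, and the import path follows networks.py). The passthrough discriminator above consumes exactly this tuple.

import torch
import models.archs.RRDBNetXL_arch as RRDBNetXL_arch

netG = RRDBNetXL_arch.RRDBNet(in_nc=3, out_nc=3, nf=64,
                              nb_lo=18, nb_med=8, nb_hi=6,
                              interpolation_scale_factor=2)
out, med_skip, lo_skip = netG(torch.randn(1, 3, 64, 64))
# out:      (1, 3, 256, 256)  full-resolution result (4x the LR input)
# med_skip: (1, 3, 128, 128)  half-resolution skip output
# lo_skip:  (1, 3, 64, 64)    quarter-resolution skip output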

View File

@@ -2,9 +2,10 @@ import torch
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.DiscriminatorResnetBN_arch as DiscriminatorResnetBN_arch
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
#import models.archs.EDVR_arch as EDVR_arch
import models.archs.HighToLowResNet as HighToLowResNet
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
@@ -26,6 +27,11 @@ def define_G(opt):
scale_per_step = math.sqrt(scale)
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
elif which_model == 'RRDBNetXL':
scale_per_step = math.sqrt(scale)
netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
interpolation_scale_factor=scale_per_step)
# image corruption
elif which_model == 'HighToLowResNet':
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
@@ -59,6 +65,8 @@ def define_D(opt):
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_passthrough':
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD

View File

@@ -16,7 +16,7 @@ datasets:
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
doCrop: false
use_shuffle: true
n_workers: 12 # per GPU
n_workers: 0 # per GPU
batch_size: 40
target_size: 256
color: RGB
@@ -43,7 +43,7 @@ path:
pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
pretrain_model_D: ~
strict_load: true
resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state
resume_state: ../experiments/blacked_fix_and_upconv/training_state/16500.state
#### training settings: learning rate scheme, loss
train:

View File

@@ -0,0 +1,87 @@
#### general settings
name: blacked_fix_and_upconv_xl
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1
#### datasets
datasets:
train:
name: vixcloseup
mode: LQGT
dataroot_GT: K:\4k6k\4k_closeup\hr
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
doCrop: false
use_shuffle: true
n_workers: 8 # per GPU
batch_size: 6
target_size: 256
color: RGB
val:
name: adrianna_val
mode: LQGT
dataroot_GT: E:\4k6k\datasets\adrianna\val\hhq
dataroot_LQ: E:\4k6k\datasets\adrianna\val\hr
#### network structures
network_G:
which_model_G: RRDBNetXL
in_nc: 3
out_nc: 3
nf: 64
nblo: 18
nbmed: 8
nbhi: 6
network_D:
which_model_D: discriminator_resnet_passthrough
in_nc: 3
nf: 42
#### path
path:
pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
pretrain_model_D: ~
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [20000, 40000, 50000, 60000]
lr_gamma: 0.5
mega_batch_factor: 1
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
feature_weight_decay: 1
feature_weight_decay_steps: 500
feature_weight_minimum: 1
gan_type: gan # gan | ragan
gan_weight: !!float 1e-2
D_update_ratio: 1
D_init_iters: -1
manual_seed: 10
val_freq: !!float 5e2
#### logger
logger:
print_freq: 50
save_checkpoint_freq: !!float 5e2

View File

@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)