From 3b4e54c4c5018dd686e19664a37040b97ec391ac Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 4 May 2020 14:01:43 -0600 Subject: [PATCH] Add support for passthrough disc/gen Add RRDBNetXL, which performs processing at multiple image sizes. Add DiscResnet_passthrough, which allows passthrough of image at different sizes for discrimination. Adjust the rest of the repo to allow generators that return more than just a single image. --- codes/models/SRGAN_model.py | 56 ++++- .../archs/DiscriminatorResnetBN_arch.py | 161 -------------- .../DiscriminatorResnet_arch_passthrough.py | 207 ++++++++++++++++++ codes/models/archs/RRDBNetXL_arch.py | 98 +++++++++ codes/models/networks.py | 10 +- codes/options/train/train_ESRGAN_blacked.yml | 4 +- .../options/train/train_ESRGAN_blacked_xl.yml | 87 ++++++++ codes/train.py | 2 +- 8 files changed, 448 insertions(+), 177 deletions(-) delete mode 100644 codes/models/archs/DiscriminatorResnetBN_arch.py create mode 100644 codes/models/archs/DiscriminatorResnet_arch_passthrough.py create mode 100644 codes/models/archs/RRDBNetXL_arch.py create mode 100644 codes/options/train/train_ESRGAN_blacked_xl.yml diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 7ee8589d..99068573 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -8,6 +8,7 @@ import models.lr_scheduler as lr_scheduler from models.base_model import BaseModel from models.loss import GANLoss from apex import amp +import torch.nn.functional as F import torchvision.utils as utils import os @@ -156,10 +157,21 @@ class SRGANModel(BaseModel): if step > self.D_init_iters: self.optimizer_G.zero_grad() - self.fake_H = [] + self.fake_GenOut = [] for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): - fake_H = self.netG(var_L) - self.fake_H.append(fake_H.detach()) + fake_GenOut = self.netG(var_L) + + # Extract the image output. For generators that output skip-through connections, the master output is always + # the first element of the tuple. + if isinstance(fake_GenOut, tuple): + fake_H = fake_GenOut[0] + # TODO: Fix this. + self.fake_GenOut.append((fake_GenOut[0].detach(), + fake_GenOut[1].detach(), + fake_GenOut[2].detach())) + else: + fake_H = fake_GenOut + self.fake_GenOut.append(fake_GenOut.detach()) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: @@ -178,11 +190,11 @@ class SRGANModel(BaseModel): self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay) if self.opt['train']['gan_type'] == 'gan': - pred_g_fake = self.netD(fake_H) + pred_g_fake = self.netD(fake_GenOut) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': pred_d_real = self.netD(var_ref).detach() - pred_g_fake = self.netD(fake_H) + pred_g_fake = self.netD(fake_GenOut) l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 @@ -196,8 +208,17 @@ class SRGANModel(BaseModel): for p in self.netD.parameters(): p.requires_grad = True + # Convert var_ref to have the same output format as the generator. This generally means interpolating the + # HR images to have the same output dimensions as each generator skip connection. 
+ if isinstance(self.fake_GenOut[0], tuple): + var_ref_skips = [] + for ref, hi_res in zip(self.var_ref, self.var_H): + var_ref_skips.append((ref,) + self.create_artificial_skips(hi_res)) + else: + var_ref_skips = self.var_ref + self.optimizer_D.zero_grad() - for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H): + for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut): if self.opt['train']['gan_type'] == 'gan': # need to forward and backward separately, since batch norm statistics differ # real @@ -206,7 +227,7 @@ class SRGANModel(BaseModel): with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: l_d_real_scaled.backward() # fake - pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G + pred_d_fake = self.netD(fake_H) # detach to avoid BP to G l_d_fake = self.cri_gan(pred_d_fake, False) with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() @@ -217,12 +238,12 @@ class SRGANModel(BaseModel): # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) # l_d_total = (l_d_real + l_d_fake) / 2 # l_d_total.backward() - pred_d_fake = self.netD(fake_H.detach()).detach() + pred_d_fake = self.netD(fake_H).detach() pred_d_real = self.netD(var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: l_d_real_scaled.backward() - pred_d_fake = self.netD(fake_H.detach()) + pred_d_fake = self.netD(fake_H) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() @@ -234,11 +255,14 @@ class SRGANModel(BaseModel): os.makedirs("temp/lr", exist_ok=True) os.makedirs("temp/gen", exist_ok=True) os.makedirs("temp/pix", exist_ok=True) + gen_batch = self.fake_GenOut[0] + if isinstance(gen_batch, tuple): + gen_batch = gen_batch[0] for i in range(self.var_L[0].shape[0]): utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i))) utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i))) utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i))) - utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i))) + utils.save_image(gen_batch[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i))) # set log TODO(handle mega-batches?) 
if step % self.D_update_ratio == 0 and step > self.D_init_iters: @@ -253,10 +277,15 @@ class SRGANModel(BaseModel): self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + def create_artificial_skips(self, truth_img): + med_skip = F.interpolate(truth_img, scale_factor=.5) + lo_skip = F.interpolate(truth_img, scale_factor=.25) + return med_skip, lo_skip + def test(self): self.netG.eval() with torch.no_grad(): - self.fake_H = [self.netG(self.var_L[0])] + self.fake_GenOut = [self.netG(self.var_L[0])] self.netG.train() def get_current_log(self): @@ -265,7 +294,10 @@ class SRGANModel(BaseModel): def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu() - out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu() + gen_batch = self.fake_GenOut[0] + if isinstance(gen_batch, tuple): + gen_batch = gen_batch[0] + out_dict['rlt'] = gen_batch.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H[0].detach()[0].float().cpu() return out_dict diff --git a/codes/models/archs/DiscriminatorResnetBN_arch.py b/codes/models/archs/DiscriminatorResnetBN_arch.py deleted file mode 100644 index 395d438b..00000000 --- a/codes/models/archs/DiscriminatorResnetBN_arch.py +++ /dev/null @@ -1,161 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F - - -__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__() - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = downsample - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.lrelu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - identity = torch.cat((identity, torch.zeros_like(identity)), 1) - - out += identity - out = self.lrelu(out) - - return out - - -class ResNet(nn.Module): - - def __init__(self, block, layers, num_filters=16, num_classes=10): - super(ResNet, self).__init__() - self.num_layers = sum(layers) - self.inplanes = num_filters - self.conv1 = conv3x3(3, num_filters) - self.bn1 = nn.BatchNorm2d(num_filters) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.layer1 = self._make_layer(block, num_filters, layers[0]) - self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2) - self.skip_conv1 = conv3x3(3, num_filters*2) - self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2) - self.skip_conv2 = conv3x3(3, num_filters*4) - self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2) - self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True) - self.fc2 = nn.Linear(64, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - 
nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - for m in self.modules(): - if isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1: - downsample = nn.Sequential( - nn.AvgPool2d(1, stride=stride), - nn.BatchNorm2d(self.inplanes), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) - self.inplanes = planes - for _ in range(1, blocks): - layers.append(block(planes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x, gen_skips=None): - x_dim = x.size(-1) - if gen_skips is None: - gen_skips = { - int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False), - int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False), - } - x = self.conv1(x) - x = self.bn1(x) - x = self.lrelu(x) - - x = self.layer1(x) - x = self.layer2(x) - x = (x + self.skip_conv1(gen_skips[int(x_dim/2)])) / 2 - x = self.layer3(x) - x = (x + self.skip_conv2(gen_skips[int(x_dim/4)])) / 2 - x = self.layer4(x) - - x = x.view(x.size(0), -1) - x = self.lrelu(self.fc1(x)) - x = self.fc2(x) - - return x - - -def resnet20(**kwargs): - """Constructs a ResNet-20 model. - """ - model = ResNet(BasicBlock, [3, 3, 3], **kwargs) - return model - - -def resnet32(**kwargs): - """Constructs a ResNet-32 model. - """ - model = ResNet(BasicBlock, [5, 5, 5], **kwargs) - return model - - -def resnet44(**kwargs): - """Constructs a ResNet-44 model. - """ - model = ResNet(BasicBlock, [7, 7, 7], **kwargs) - return model - - -def resnet56(**kwargs): - """Constructs a ResNet-56 model. - """ - model = ResNet(BasicBlock, [9, 9, 9], **kwargs) - return model - - -def resnet110(**kwargs): - """Constructs a ResNet-110 model. - """ - model = ResNet(BasicBlock, [18, 18, 18], **kwargs) - return model - - -def resnet1202(**kwargs): - """Constructs a ResNet-1202 model. 
- """ - model = ResNet(BasicBlock, [200, 200, 200], **kwargs) - return model \ No newline at end of file diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py new file mode 100644 index 00000000..f9e1e101 --- /dev/null +++ b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import numpy as np + + +__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class FixupBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(FixupBasicBlock, self).__init__() + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.bias1a = nn.Parameter(torch.zeros(1)) + self.conv1 = conv3x3(inplanes, planes, stride) + self.bias1b = nn.Parameter(torch.zeros(1)) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.bias2a = nn.Parameter(torch.zeros(1)) + self.conv2 = conv3x3(planes, planes) + self.scale = nn.Parameter(torch.ones(1)) + self.bias2b = nn.Parameter(torch.zeros(1)) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x + self.bias1a) + out = self.lrelu(out + self.bias1b) + + out = self.conv2(out + self.bias2a) + out = out * self.scale + self.bias2b + + if self.downsample is not None: + identity = self.downsample(x + self.bias1a) + + out += identity + out = self.lrelu(out) + + return out + +class FixupBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(FixupBottleneck, self).__init__() + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.bias1a = nn.Parameter(torch.zeros(1)) + self.conv1 = conv1x1(inplanes, planes) + self.bias1b = nn.Parameter(torch.zeros(1)) + self.bias2a = nn.Parameter(torch.zeros(1)) + self.conv2 = conv3x3(planes, planes, stride) + self.bias2b = nn.Parameter(torch.zeros(1)) + self.bias3a = nn.Parameter(torch.zeros(1)) + self.conv3 = conv1x1(planes, planes * self.expansion) + self.scale = nn.Parameter(torch.ones(1)) + self.bias3b = nn.Parameter(torch.zeros(1)) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x + self.bias1a) + out = self.lrelu(out + self.bias1b) + + out = self.conv2(out + self.bias2a) + out = self.lrelu(out + self.bias2b) + + out = self.conv3(out + self.bias3a) + out = out * self.scale + self.bias3b + + if self.downsample is not None: + identity = self.downsample(x + self.bias1a) + + out += identity + out = self.lrelu(out) + + return out + + +class FixupResNet(nn.Module): + + def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64): + super(FixupResNet, self).__init__() + self.num_layers = sum(layers) + self.inplanes = num_filters + self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3, + bias=False) + self.bias1 = nn.Parameter(torch.zeros(1)) + self.lrelu = 
nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1) + self.skip1 = nn.Conv2d(num_filters + 3, num_filters, kernel_size=5, stride=1, padding=2, bias=False) + self.skip1_bias = nn.Parameter(torch.zeros(1)) + self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2) + self.skip2 = nn.Conv2d(num_filters*2 + 3, num_filters*2, kernel_size=5, stride=1, padding=2, bias=False) + self.skip2_bias = nn.Parameter(torch.zeros(1)) + self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2) + self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) + self.layer5 = self._make_layer(block, num_filters*16, layers[4], stride=2) + self.bias2 = nn.Parameter(torch.zeros(1)) + reduced_img_sz = int(input_img_size / 32) + self.fc1 = nn.Linear(num_filters * 16 * reduced_img_sz * reduced_img_sz, 100) + self.fc2 = nn.Linear(100, num_classes) + + for m in self.modules(): + if isinstance(m, FixupBasicBlock): + nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) + nn.init.constant_(m.conv2.weight, 0) + if m.downsample is not None: + nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) + elif isinstance(m, FixupBottleneck): + nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25)) + nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25)) + nn.init.constant_(m.conv3.weight, 0) + if m.downsample is not None: + nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) + ''' + elif isinstance(m, nn.Linear): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0)''' + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = conv1x1(self.inplanes, planes * block.expansion, stride) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + # This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input. + hi, med_skip, lo_skip = x + + x = self.conv1(hi) + x = self.lrelu(x + self.bias1) + x = self.layer1(x) + x = self.lrelu(self.skip1(torch.cat([x, med_skip], dim=1)) + self.skip1_bias) + + x = self.layer2(x) + x = self.lrelu(self.skip2(torch.cat([x, lo_skip], dim=1)) + self.skip2_bias) + + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + + x = x.view(x.size(0), -1) + x = self.lrelu(self.fc1(x)) + x = self.fc2(x + self.bias2) + + return x + + +def fixup_resnet18(**kwargs): + """Constructs a Fixup-ResNet-18 model.2 + """ + model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2, 2], **kwargs) + return model + + +def fixup_resnet34(**kwargs): + """Constructs a Fixup-ResNet-34 model. + """ + model = FixupResNet(FixupBasicBlock, [5, 4, 3, 3, 2], **kwargs) + return model + + +def fixup_resnet50(**kwargs): + """Constructs a Fixup-ResNet-50 model. 
+ """ + model = FixupResNet(FixupBottleneck, [3, 4, 6, 3, 2], **kwargs) + return model + + +def fixup_resnet101(**kwargs): + """Constructs a Fixup-ResNet-101 model. + """ + model = FixupResNet(FixupBottleneck, [3, 4, 23, 3, 2], **kwargs) + return model + + +def fixup_resnet152(**kwargs): + """Constructs a Fixup-ResNet-152 model. + """ + model = FixupResNet(FixupBottleneck, [3, 8, 36, 3, 2], **kwargs) + return model + + +__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] \ No newline at end of file diff --git a/codes/models/archs/RRDBNetXL_arch.py b/codes/models/archs/RRDBNetXL_arch.py new file mode 100644 index 00000000..9831aa9f --- /dev/null +++ b/codes/models/archs/RRDBNetXL_arch.py @@ -0,0 +1,98 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.archs.arch_util as arch_util + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], + 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb_lo, nb_med, nb_hi, gc=32, interpolation_scale_factor=2): + super(RRDBNet, self).__init__() + nfmed = int(nf/2) + nfhi = int(nf/8) + gcmed = int(gc/2) + gchi = int(gc/8) + RRDB_block_f_lo = functools.partial(RRDB, nf=nf, gc=gc) + RRDB_block_f_lo_med = functools.partial(RRDB, nf=nfmed, gc=gcmed) + RRDB_block_f_lo_hi = functools.partial(RRDB, nf=nfhi, gc=gchi) + + self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True) + self.RRDB_trunk_lo = arch_util.make_layer(RRDB_block_f_lo, nb_lo) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.lo_skip_conv1 = nn.Conv2d(nf, nf, 3, 1, padding=1, bias=True) + self.lo_skip_conv2 = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + #### upsampling + self.upconv1 = nn.Conv2d(nf, nfmed, 3, 1, padding=1, bias=True) + self.RRDB_trunk_med = arch_util.make_layer(RRDB_block_f_lo_med, nb_med) + self.trunk_conv_med = nn.Conv2d(nfmed, nfmed, 3, 1, 1, bias=True) + self.med_skip_conv1 = nn.Conv2d(nfmed, nfmed, 3, 1, padding=1, bias=True) + self.med_skip_conv2 = nn.Conv2d(nfmed, out_nc, 3, 1, 1, bias=True) + + self.upconv2 = nn.Conv2d(nfmed, nfhi, 3, 1, padding=1, bias=True) + self.RRDB_trunk_hi = arch_util.make_layer(RRDB_block_f_lo_hi, nb_hi) + 
self.trunk_conv_hi = nn.Conv2d(nfhi, nfhi, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nfhi, nfhi, 5, 1, padding=2, bias=True) + self.conv_last = nn.Conv2d(nfhi, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + self.interpolation_scale_factor = interpolation_scale_factor + + def forward(self, x): + fea = self.conv_first(x) + branch = self.trunk_conv(self.RRDB_trunk_lo(fea)) + fea = (fea + branch) / 2 + lo_skip = self.lo_skip_conv2(self.lrelu(self.lo_skip_conv1(fea))) + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest'))) + branch = self.trunk_conv_med(self.RRDB_trunk_med(fea)) + fea = (fea + branch) / 2 + med_skip = self.med_skip_conv2(self.lrelu(self.med_skip_conv1(fea))) + + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest'))) + branch = self.trunk_conv_hi(self.RRDB_trunk_hi(fea)) + fea = (fea + branch) / 2 + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out, med_skip, lo_skip \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index af59b905..0e66bb61 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -2,9 +2,10 @@ import torch import models.archs.SRResNet_arch as SRResNet_arch import models.archs.discriminator_vgg_arch as SRGAN_arch import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch -import models.archs.DiscriminatorResnetBN_arch as DiscriminatorResnetBN_arch +import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch import models.archs.RRDBNet_arch as RRDBNet_arch +import models.archs.RRDBNetXL_arch as RRDBNetXL_arch #import models.archs.EDVR_arch as EDVR_arch import models.archs.HighToLowResNet as HighToLowResNet import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch @@ -26,6 +27,11 @@ def define_G(opt): scale_per_step = math.sqrt(scale) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step) + elif which_model == 'RRDBNetXL': + scale_per_step = math.sqrt(scale) + netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'], + interpolation_scale_factor=scale_per_step) # image corruption elif which_model == 'HighToLowResNet': netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], @@ -59,6 +65,8 @@ def define_D(opt): netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) + elif which_model == 'discriminator_resnet_passthrough': + netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/options/train/train_ESRGAN_blacked.yml b/codes/options/train/train_ESRGAN_blacked.yml index 9e51ce5f..009b55c9 100644 --- a/codes/options/train/train_ESRGAN_blacked.yml +++ b/codes/options/train/train_ESRGAN_blacked.yml @@ -16,7 +16,7 @@ datasets: dataroot_LQ: 
K:\4k6k\4k_closeup\lr_corrupted doCrop: false use_shuffle: true - n_workers: 12 # per GPU + n_workers: 0 # per GPU batch_size: 40 target_size: 256 color: RGB @@ -43,7 +43,7 @@ path: pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth pretrain_model_D: ~ strict_load: true - resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state + resume_state: ../experiments/blacked_fix_and_upconv/training_state/16500.state #### training settings: learning rate scheme, loss train: diff --git a/codes/options/train/train_ESRGAN_blacked_xl.yml b/codes/options/train/train_ESRGAN_blacked_xl.yml new file mode 100644 index 00000000..7f985e25 --- /dev/null +++ b/codes/options/train/train_ESRGAN_blacked_xl.yml @@ -0,0 +1,87 @@ +#### general settings +name: blacked_fix_and_upconv_xl +use_tb_logger: true +model: srgan +distortion: sr +scale: 4 +gpu_ids: [0] +amp_opt_level: O1 + +#### datasets +datasets: + train: + name: vixcloseup + mode: LQGT + dataroot_GT: K:\4k6k\4k_closeup\hr + dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted + doCrop: false + use_shuffle: true + n_workers: 8 # per GPU + batch_size: 6 + target_size: 256 + color: RGB + val: + name: adrianna_val + mode: LQGT + dataroot_GT: E:\4k6k\datasets\adrianna\val\hhq + dataroot_LQ: E:\4k6k\datasets\adrianna\val\hr + +#### network structures +network_G: + which_model_G: RRDBNetXL + in_nc: 3 + out_nc: 3 + nf: 64 + nblo: 18 + nbmed: 8 + nbhi: 6 +network_D: + which_model_D: discriminator_resnet_passthrough + in_nc: 3 + nf: 42 + +#### path +path: + pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth + pretrain_model_D: ~ + strict_load: true + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + weight_decay_G: 0 + beta1_G: 0.9 + beta2_G: 0.99 + lr_D: !!float 1e-4 + weight_decay_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + lr_scheme: MultiStepLR + + niter: 400000 + warmup_iter: -1 # no warm up + lr_steps: [20000, 40000, 50000, 60000] + lr_gamma: 0.5 + mega_batch_factor: 1 + + pixel_criterion: l1 + pixel_weight: !!float 1e-2 + feature_criterion: l1 + feature_weight: 1 + feature_weight_decay: 1 + feature_weight_decay_steps: 500 + feature_weight_minimum: 1 + gan_type: gan # gan | ragan + gan_weight: !!float 1e-2 + + D_update_ratio: 1 + D_init_iters: -1 + + manual_seed: 10 + val_freq: !!float 5e2 + +#### logger +logger: + print_freq: 50 + save_checkpoint_freq: !!float 5e2 diff --git a/codes/train.py b/codes/train.py index 645ac867..18691a4d 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)
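
For reference, the contract this patch establishes: generators such as RRDBNetXL may now return a tuple whose first element is always the master (full-resolution) image, with the remaining elements being lower-resolution skip outputs, and the real/reference branch is converted to the same structure via create_artificial_skips() before reaching netD. The sketch below illustrates that contract with dummy tensors; the 64px-to-256px shapes and the stand-in tuple are illustrative assumptions, while create_artificial_skips mirrors the method added to SRGANModel.

    # Illustrative sketch (not repo code) of the tuple contract added by this patch.
    import torch
    import torch.nn.functional as F

    def create_artificial_skips(truth_img):
        # Mirrors SRGANModel.create_artificial_skips: half- and quarter-res copies.
        med_skip = F.interpolate(truth_img, scale_factor=.5)
        lo_skip = F.interpolate(truth_img, scale_factor=.25)
        return med_skip, lo_skip

    # Assumed shapes: a 4x model upscaling 64px LR crops to 256px HR targets.
    hr_ref = torch.randn(2, 3, 256, 256)

    # An 'RRDBNetXL' generator returns a 3-tuple; element 0 is the master image.
    fake_gen_out = (torch.randn(2, 3, 256, 256),   # out
                    torch.randn(2, 3, 128, 128),   # med_skip
                    torch.randn(2, 3, 64, 64))     # lo_skip

    # The real branch is given the same structure before it is fed to netD.
    var_ref_skips = (hr_ref,) + create_artificial_skips(hr_ref)

    for fake, real in zip(fake_gen_out, var_ref_skips):
        assert fake.shape == real.shape   # fake and real tuples match scale for scale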
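
The new passthrough discriminator consumes that tuple by concatenating each skip image onto its feature maps at the matching resolution and fusing with a 5x5 convolution. A minimal shape walk-through follows, using plain convolutions as stand-ins for the Fixup residual layers and assuming a 256px input with nf=42 (as in train_ESRGAN_blacked_xl.yml); it is a sketch of the data flow, not the FixupResNet implementation itself.

    # Stand-in convs only; the real layers are Fixup residual blocks.
    import torch
    import torch.nn as nn

    nf = 42
    hi, med_skip, lo_skip = (torch.randn(1, 3, 256, 256),
                             torch.randn(1, 3, 128, 128),
                             torch.randn(1, 3, 64, 64))

    conv1  = nn.Conv2d(3, nf, 7, stride=2, padding=3, bias=False)       # 256 -> 128
    skip1  = nn.Conv2d(nf + 3, nf, 5, stride=1, padding=2, bias=False)  # fold in med_skip
    layer2 = nn.Conv2d(nf, nf * 2, 3, stride=2, padding=1)              # stand-in: 128 -> 64
    skip2  = nn.Conv2d(nf * 2 + 3, nf * 2, 5, stride=1, padding=2, bias=False)

    x = conv1(hi)                               # 128x128, same size as med_skip
    x = skip1(torch.cat([x, med_skip], dim=1))  # channel concat, then fuse
    x = layer2(x)                               # 64x64, same size as lo_skip
    x = skip2(torch.cat([x, lo_skip], dim=1))
    print(x.shape)                              # torch.Size([1, 84, 64, 64])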