PyramidRRDB net
This commit is contained in:
parent
a1760f8969
commit
72762f200c
|
@ -116,7 +116,7 @@ class RRDBWithBypass(nn.Module):
|
||||||
out = self.rdb3(out)
|
out = self.rdb3(out)
|
||||||
bypass = self.bypass(torch.cat([x, out], dim=1))
|
bypass = self.bypass(torch.cat([x, out], dim=1))
|
||||||
self.bypass_map = bypass.detach().clone()
|
self.bypass_map = bypass.detach().clone()
|
||||||
# Emperically, we use 0.2 to scale the residual for better performance
|
# Empirically, we use 0.2 to scale the residual for better performance
|
||||||
return out * 0.2 * bypass + x
|
return out * 0.2 * bypass + x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
|
||||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
|
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
||||||
|
from models.archs.pyramid_arch import Pyramid
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,6 +81,7 @@ class Discriminator_VGG_128(nn.Module):
|
||||||
out = self.linear2(fea)
|
out = self.linear2(fea)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Discriminator_VGG_128_GN(nn.Module):
|
class Discriminator_VGG_128_GN(nn.Module):
|
||||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||||
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False):
|
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False):
|
||||||
|
@ -656,3 +660,26 @@ class SingleImageQualityEstimator(nn.Module):
|
||||||
fea = self.lrelu(self.conv4_2(fea))
|
fea = self.lrelu(self.conv4_2(fea))
|
||||||
fea = self.sigmoid(self.conv4_3(fea))
|
fea = self.sigmoid(self.conv4_3(fea))
|
||||||
return fea
|
return fea
|
||||||
|
|
||||||
|
|
||||||
|
class PyramidRRDBDiscriminator(nn.Module):
|
||||||
|
def __init__(self, in_nc, nf, block=ConvGnLelu):
|
||||||
|
super(PyramidRRDBDiscriminator, self).__init__()
|
||||||
|
self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True)
|
||||||
|
self.top_proc = nn.Sequential(*[RRDBWithBypass(nf),
|
||||||
|
RRDBWithBypass(nf)])
|
||||||
|
self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2,
|
||||||
|
scale_per_level=1.5, norm=True, return_outlevels=False)
|
||||||
|
self.bottom_proc = nn.Sequential(*[RRDBWithBypass(nf),
|
||||||
|
RRDBWithBypass(nf),
|
||||||
|
ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
|
||||||
|
ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
|
||||||
|
ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.initial_conv(x)
|
||||||
|
fea = checkpoint(self.top_proc, fea)
|
||||||
|
fea = checkpoint(self.pyramid, fea)
|
||||||
|
fea = checkpoint(self.bottom_proc, fea)
|
||||||
|
return torch.mean(fea, dim=[1,2,3])
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
|
from models.archs.arch_util import ConvGnLelu, ExpansionBlock
|
||||||
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
|
@ -187,6 +187,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
elif which_model == "psnr_approximator":
|
elif which_model == "psnr_approximator":
|
||||||
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
|
elif which_model == "pyramid_rrdb_disc":
|
||||||
|
netD = SRGAN_arch.PyramidRRDBDiscriminator(in_nc=3, nf=opt_net['nf'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
Loading…
Reference in New Issue
Block a user