Convert PyramidRRDBDisc to RRDBDisc

Had numeric stability issues. This probably makes more sense anyways.
This commit is contained in:
James Betker 2020-11-11 12:14:14 -07:00
parent 72762f200c
commit 42a97de756
2 changed files with 17 additions and 17 deletions

View File

@ -6,7 +6,7 @@ from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvG
import torch.nn.functional as F import torch.nn.functional as F
from models.archs.SwitchedResidualGenerator_arch import gather_2d from models.archs.SwitchedResidualGenerator_arch import gather_2d
from models.archs.pyramid_arch import Pyramid from models.archs.pyramid_arch import Pyramid
from utils.util import checkpoint from utils.util import checkpoint, sequential_checkpoint
class Discriminator_VGG_128(nn.Module): class Discriminator_VGG_128(nn.Module):
@ -662,24 +662,24 @@ class SingleImageQualityEstimator(nn.Module):
return fea return fea
class PyramidRRDBDiscriminator(nn.Module): class RRDBDiscriminator(nn.Module):
def __init__(self, in_nc, nf, block=ConvGnLelu): def __init__(self, in_nc, nf, block=ConvGnLelu):
super(PyramidRRDBDiscriminator, self).__init__() super(RRDBDiscriminator, self).__init__()
self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True) self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True)
self.top_proc = nn.Sequential(*[RRDBWithBypass(nf), self.trunk = nn.ModuleList(*[RRDBWithBypass(nf),
RRDBWithBypass(nf)]) RRDBWithBypass(nf),
self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2, RRDBWithBypass(nf),
scale_per_level=1.5, norm=True, return_outlevels=False) RRDBWithBypass(nf),
self.bottom_proc = nn.Sequential(*[RRDBWithBypass(nf), RRDBWithBypass(nf)])
RRDBWithBypass(nf),
ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True), self.tail = nn.Sequential(*[
ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True), ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)]) ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
def forward(self, x): def forward(self, x):
fea = self.initial_conv(x) fea = self.initial_conv(x)
fea = checkpoint(self.top_proc, fea) fea = sequential_checkpoint(self.top_proc, 2, fea)
fea = checkpoint(self.pyramid, fea) fea = checkpoint(self.tail, fea)
fea = checkpoint(self.bottom_proc, fea)
return torch.mean(fea, dim=[1,2,3]) return torch.mean(fea, dim=[1,2,3])

View File

@ -187,8 +187,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "psnr_approximator": elif which_model == "psnr_approximator":
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128) netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "pyramid_rrdb_disc": elif which_model == "rrdb_disc":
netD = SRGAN_arch.PyramidRRDBDiscriminator(in_nc=3, nf=opt_net['nf']) netD = SRGAN_arch.RRDBDiscriminator(in_nc=3, nf=opt_net['nf'])
else: else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD return netD