Remove pyramid_disc hard dependencies
This commit is contained in:
parent
6b679e2b51
commit
5c10264538
|
@ -5,7 +5,6 @@ from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
|
||||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
||||||
from models.archs.pyramid_arch import Pyramid
|
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@ -660,29 +659,3 @@ class SingleImageQualityEstimator(nn.Module):
|
||||||
fea = self.lrelu(self.conv4_2(fea))
|
fea = self.lrelu(self.conv4_2(fea))
|
||||||
fea = self.sigmoid(self.conv4_3(fea))
|
fea = self.sigmoid(self.conv4_3(fea))
|
||||||
return fea
|
return fea
|
||||||
|
|
||||||
|
|
||||||
class PyramidDiscriminator(nn.Module):
|
|
||||||
def __init__(self, in_nc, nf, block=ConvGnLelu):
|
|
||||||
super(PyramidDiscriminator, self).__init__()
|
|
||||||
self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True)
|
|
||||||
self.top_proc = nn.Sequential(*[ResidualBlockGN(nf),
|
|
||||||
ResidualBlockGN(nf),
|
|
||||||
ResidualBlockGN(nf)])
|
|
||||||
self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2,
|
|
||||||
scale_per_level=1.5, norm=True, return_outlevels=False)
|
|
||||||
self.bottom_proc = nn.Sequential(*[ResidualBlockGN(nf),
|
|
||||||
ResidualBlockGN(nf),
|
|
||||||
ResidualBlockGN(nf),
|
|
||||||
ResidualBlockGN(nf),
|
|
||||||
ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=False, bias=True),
|
|
||||||
ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=False, bias=True),
|
|
||||||
ConvGnLelu(nf // 4, 1, kernel_size=1, activation=False, norm=False, bias=True)])
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
fea = self.initial_conv(x)
|
|
||||||
fea = checkpoint(self.top_proc, fea)
|
|
||||||
fea = checkpoint(self.pyramid, fea)
|
|
||||||
fea = checkpoint(self.bottom_proc, fea)
|
|
||||||
return torch.mean(fea, dim=[1,2,3])
|
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ import models.archs.rcan as rcan
|
||||||
from models.archs import srg2_classic
|
from models.archs import srg2_classic
|
||||||
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
|
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
|
||||||
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||||
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
|
||||||
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
|
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
|
||||||
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
||||||
from models.archs.teco_resgen import TecoGen
|
from models.archs.teco_resgen import TecoGen
|
||||||
|
@ -198,8 +197,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
elif which_model == "psnr_approximator":
|
elif which_model == "psnr_approximator":
|
||||||
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
elif which_model == "pyramid_disc":
|
|
||||||
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
|
||||||
elif which_model == "stylegan2_discriminator":
|
elif which_model == "stylegan2_discriminator":
|
||||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user