resnext discriminator
This commit is contained in:
parent
55f2764fef
commit
8beaa47933
|
@ -12,6 +12,7 @@ import models.archs.SPSR_arch as spsr
|
||||||
import models.archs.StructuredSwitchedGenerator as ssg
|
import models.archs.StructuredSwitchedGenerator as ssg
|
||||||
import models.archs.rcan as rcan
|
import models.archs.rcan as rcan
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import torchvision
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
@ -126,6 +127,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
|
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
|
||||||
number_skips=opt_net['number_skips'], use_bn=True,
|
number_skips=opt_net['number_skips'], use_bn=True,
|
||||||
disable_passthrough=opt_net['disable_passthrough'])
|
disable_passthrough=opt_net['disable_passthrough'])
|
||||||
|
elif which_model == 'resnext':
|
||||||
|
netD = torchvision.models.resnext50_32x4d(pretrained=True)
|
||||||
|
netD.fc = torch.nn.Linear(512 * 4, 1)
|
||||||
elif which_model == 'discriminator_pix':
|
elif which_model == 'discriminator_pix':
|
||||||
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||||
elif which_model == "discriminator_unet":
|
elif which_model == "discriminator_unet":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user