forked from mrq/DL-Art-School
Add capability to place additional conv into discriminator
This should allow us to support larger images sizes. May need to add another one of these.
This commit is contained in:
parent
bad33de906
commit
dfcbe5f2db
|
@ -5,7 +5,7 @@ import torchvision
|
|||
|
||||
class Discriminator_VGG_128(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, in_nc, nf, input_img_factor=1):
|
||||
def __init__(self, in_nc, nf, input_img_factor=1, extra_conv=False):
|
||||
super(Discriminator_VGG_128, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
|
@ -31,8 +31,18 @@ class Discriminator_VGG_128(nn.Module):
|
|||
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
final_nf = nf * 8
|
||||
|
||||
self.linear1 = nn.Linear(int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||
self.extra_conv = extra_conv
|
||||
if self.extra_conv:
|
||||
self.conv5_0 = nn.Conv2d(nf * 8, nf * 16, 3, 1, 1, bias=False)
|
||||
self.bn5_0 = nn.BatchNorm2d(nf * 16, affine=True)
|
||||
self.conv5_1 = nn.Conv2d(nf * 16, nf * 16, 4, 2, 1, bias=False)
|
||||
self.bn5_1 = nn.BatchNorm2d(nf * 16, affine=True)
|
||||
input_img_factor = input_img_factor // 2
|
||||
final_nf = nf * 16
|
||||
|
||||
self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
|
||||
self.linear2 = nn.Linear(100, 1)
|
||||
|
||||
# activation function
|
||||
|
@ -57,6 +67,10 @@ class Discriminator_VGG_128(nn.Module):
|
|||
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||
|
||||
if self.extra_conv:
|
||||
fea = self.lrelu(self.bn5_0(self.conv5_0(fea)))
|
||||
fea = self.lrelu(self.bn5_1(self.conv5_1(fea)))
|
||||
|
||||
fea = fea.contiguous().view(fea.size(0), -1)
|
||||
fea = self.lrelu(self.linear1(fea))
|
||||
out = self.linear2(fea)
|
||||
|
|
|
@ -95,7 +95,7 @@ def define_D(opt):
|
|||
which_model = opt_net['which_model_D']
|
||||
|
||||
if which_model == 'discriminator_vgg_128':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv'])
|
||||
elif which_model == 'discriminator_resnet':
|
||||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||
elif which_model == 'discriminator_resnet_passthrough':
|
||||
|
|
Loading…
Reference in New Issue
Block a user