From 05aafef938e85f0243685cd6f51a03559d0020c4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 22 Apr 2020 00:39:55 -0600 Subject: [PATCH] Support variant input sizes and scales --- codes/models/archs/RRDBNet_arch.py | 8 +++++--- codes/models/archs/discriminator_vgg_arch.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 9d61256c..4025726f 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -46,7 +46,7 @@ class RRDB(nn.Module): class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, interpolation_scale_factor=2): super(RRDBNet, self).__init__() RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) @@ -61,13 +61,15 @@ class RRDBNet(nn.Module): self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.interpolation_scale_factor = interpolation_scale_factor + def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(self.RRDB_trunk(fea)) fea = fea + trunk - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.interpolation_scale_factor, mode='nearest'))) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 006974ec..ae51ba16 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -32,7 +32,7 @@ class Discriminator_VGG_128(nn.Module): self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) - self.linear1 = nn.Linear(512 * 4 * input_img_factor * 4 * input_img_factor, 100) + self.linear1 = nn.Linear(int(512 * 4 * input_img_factor * 4 * input_img_factor), 100) self.linear2 = nn.Linear(100, 1) # activation function