diff --git a/codes/models/archs/DiscriminatorResnetBN_arch.py b/codes/models/archs/DiscriminatorResnetBN_arch.py index ccd25432..395d438b 100644 --- a/codes/models/archs/DiscriminatorResnetBN_arch.py +++ b/codes/models/archs/DiscriminatorResnetBN_arch.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import numpy as np +import torch.nn.functional as F __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] @@ -20,7 +21,7 @@ class BasicBlock(nn.Module): # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample @@ -30,7 +31,7 @@ class BasicBlock(nn.Module): out = self.conv1(x) out = self.bn1(out) - out = self.relu(out) + out = self.lrelu(out) out = self.conv2(out) out = self.bn2(out) @@ -40,7 +41,7 @@ class BasicBlock(nn.Module): identity = torch.cat((identity, torch.zeros_like(identity)), 1) out += identity - out = self.relu(out) + out = self.lrelu(out) return out @@ -53,10 +54,12 @@ class ResNet(nn.Module): self.inplanes = num_filters self.conv1 = conv3x3(3, num_filters) self.bn1 = nn.BatchNorm2d(num_filters) - self.relu = nn.ReLU(inplace=True) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layer1 = self._make_layer(block, num_filters, layers[0]) self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2) + self.skip_conv1 = conv3x3(3, num_filters*2) self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2) + self.skip_conv2 = conv3x3(3, num_filters*4) self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2) self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True) self.fc2 = nn.Linear(64, num_classes) @@ -91,18 +94,26 @@ class ResNet(nn.Module): return nn.Sequential(*layers) - def forward(self, x): + def forward(self, x, gen_skips=None): + x_dim = x.size(-1) + if gen_skips is None: + gen_skips = { + int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False), + int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False), + } x = self.conv1(x) x = self.bn1(x) - x = self.relu(x) + x = self.lrelu(x) x = self.layer1(x) x = self.layer2(x) + x = (x + self.skip_conv1(gen_skips[int(x_dim/2)])) / 2 x = self.layer3(x) + x = (x + self.skip_conv2(gen_skips[int(x_dim/4)])) / 2 x = self.layer4(x) x = x.view(x.size(0), -1) - x = self.relu(self.fc1(x)) + x = self.lrelu(self.fc1(x)) x = self.fc2(x) return x