Make resnet w/ BN discriminator use leaky relus

This commit is contained in:
James Betker 2020-04-30 11:28:59 -06:00
parent 3781ea725c
commit bf634fc9fa

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
@ -20,7 +21,7 @@ class BasicBlock(nn.Module):
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
@ -30,7 +31,7 @@ class BasicBlock(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.lrelu(out)
out = self.conv2(out)
out = self.bn2(out)
@ -40,7 +41,7 @@ class BasicBlock(nn.Module):
identity = torch.cat((identity, torch.zeros_like(identity)), 1)
out += identity
out = self.relu(out)
out = self.lrelu(out)
return out
@ -53,10 +54,12 @@ class ResNet(nn.Module):
self.inplanes = num_filters
self.conv1 = conv3x3(3, num_filters)
self.bn1 = nn.BatchNorm2d(num_filters)
self.relu = nn.ReLU(inplace=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer1 = self._make_layer(block, num_filters, layers[0])
self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2)
self.skip_conv1 = conv3x3(3, num_filters*2)
self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
self.skip_conv2 = conv3x3(3, num_filters*4)
self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2)
self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True)
self.fc2 = nn.Linear(64, num_classes)
@ -91,18 +94,26 @@ class ResNet(nn.Module):
return nn.Sequential(*layers)
def forward(self, x):
def forward(self, x, gen_skips=None):
x_dim = x.size(-1)
if gen_skips is None:
gen_skips = {
int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
}
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.lrelu(x)
x = self.layer1(x)
x = self.layer2(x)
x = (x + self.skip_conv1(gen_skips[int(x_dim/2)])) / 2
x = self.layer3(x)
x = (x + self.skip_conv2(gen_skips[int(x_dim/4)])) / 2
x = self.layer4(x)
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
x = self.lrelu(self.fc1(x))
x = self.fc2(x)
return x