Make ResNet w/ BN discriminator use leaky ReLUs
parent 3781ea725c
commit bf634fc9fa
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
+import torch.nn.functional as F
 
 
 __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
@@ -20,7 +21,7 @@ class BasicBlock(nn.Module):
         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
         self.conv1 = conv3x3(inplanes, planes, stride)
         self.bn1 = nn.BatchNorm2d(planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         self.conv2 = conv3x3(planes, planes)
         self.bn2 = nn.BatchNorm2d(planes)
         self.downsample = downsample
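
Note: the only behavioral change from the removed nn.ReLU is on negative inputs, where LeakyReLU passes 0.2*x instead of clamping to zero, keeping gradients alive in the discriminator (the standard DCGAN choice of slope). A minimal standalone sketch of the difference:

import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
print(nn.ReLU()(x))                         # tensor([0., 0., 0., 1.])
print(nn.LeakyReLU(negative_slope=0.2)(x))  # tensor([-0.4000, -0.1000, 0.0000, 1.0000])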
@@ -30,7 +31,7 @@ class BasicBlock(nn.Module):
 
         out = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.lrelu(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
@@ -40,7 +41,7 @@ class BasicBlock(nn.Module):
             identity = torch.cat((identity, torch.zeros_like(identity)), 1)
 
         out += identity
-        out = self.relu(out)
+        out = self.lrelu(out)
 
         return out
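
For context on the shortcut above: torch.cat with zeros_like is the parameter-free "option A" shortcut from the CIFAR ResNet paper, zero-padding the identity's channels when a stage doubles the width. A quick shape illustration in plain PyTorch:

import torch

identity = torch.randn(4, 16, 8, 8)                            # N, C, H, W
padded = torch.cat((identity, torch.zeros_like(identity)), 1)  # zero-pad channel dim
print(padded.shape)                                            # torch.Size([4, 32, 8, 8])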
@@ -53,10 +54,12 @@ class ResNet(nn.Module):
         self.inplanes = num_filters
         self.conv1 = conv3x3(3, num_filters)
         self.bn1 = nn.BatchNorm2d(num_filters)
-        self.relu = nn.ReLU(inplace=True)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         self.layer1 = self._make_layer(block, num_filters, layers[0])
         self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2)
+        self.skip_conv1 = conv3x3(3, num_filters*2)
         self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
+        self.skip_conv2 = conv3x3(3, num_filters*4)
         self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2)
         self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True)
         self.fc2 = nn.Linear(64, num_classes)
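
The two new skip_conv layers project the 3-channel gen_skips inputs up to the width of the layer2 and layer3 outputs so the two can be averaged in forward(). A shape sketch, where num_filters=16 and the nn.Conv2d stands in for this repo's conv3x3 helper (both assumptions):

import torch
import torch.nn as nn

num_filters = 16                                # assumed value, not fixed by this diff
skip_conv1 = nn.Conv2d(3, num_filters * 2, kernel_size=3, padding=1)

rgb_half = torch.randn(4, 3, 32, 32)            # input image downsampled by 2
feat = torch.randn(4, num_filters * 2, 32, 32)  # shape of layer2's output

mixed = (feat + skip_conv1(rgb_half)) / 2       # mirrors the averaging in forward()
print(mixed.shape)                              # torch.Size([4, 32, 32, 32])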
@@ -91,18 +94,26 @@ class ResNet(nn.Module):
 
         return nn.Sequential(*layers)
 
-    def forward(self, x):
+    def forward(self, x, gen_skips=None):
+        x_dim = x.size(-1)
+        if gen_skips is None:
+            gen_skips = {
+                int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
+                int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
+            }
         x = self.conv1(x)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.lrelu(x)
 
         x = self.layer1(x)
         x = self.layer2(x)
+        x = (x + self.skip_conv1(gen_skips[int(x_dim/2)])) / 2
         x = self.layer3(x)
+        x = (x + self.skip_conv2(gen_skips[int(x_dim/4)])) / 2
         x = self.layer4(x)
 
         x = x.view(x.size(0), -1)
-        x = self.relu(self.fc1(x))
+        x = self.lrelu(self.fc1(x))
         x = self.fc2(x)
 
         return x
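
Taken together, the new signature supports two call patterns: with gen_skips=None the discriminator bilinearly downsamples its own input, or the caller passes a dict keyed by spatial size (here 32 and 16, assuming 64x64 inputs as implied by the num_filters * 8 * 8 * 8 fc1 width). A hypothetical usage sketch; resnet20 is exported in __all__ above, but its constructor arguments are assumptions:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 64, 64)   # batch of real or generated images; 64x64 assumed

d = resnet20()                  # factory from this module's __all__; signature assumed

# Default path: forward() builds the half/quarter-resolution inputs itself.
logits = d(x)

# Explicit path: a generator could hand over its own multi-scale outputs instead.
gen_skips = {
    32: F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
    16: F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
}
logits = d(x, gen_skips=gen_skips)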