Add ConvBnSilu to replace ConvBnRelu

Relu produced good performance gains over LeakyRelu, but
GAN performance degraded significantly. Try SiLU as an alternative
to see if it's the leaky-ness we are looking for or the smooth activation
curvature.
This commit is contained in:
James Betker 2020-07-05 13:39:08 -06:00
parent 9934e5d082
commit 10f7e49214
2 changed files with 79 additions and 4 deletions

View File

@ -3,7 +3,7 @@ from torch import nn
from switched_conv import BareConvSwitch, compute_attention_specificity
import torch.nn.functional as F
import functools
from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu
from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu, ConvBnSilu
from switched_conv_util import save_attention_to_image
@ -32,8 +32,8 @@ class MultiConvBlock(nn.Module):
class HalvingProcessingBlock(nn.Module):
def __init__(self, filters):
super(HalvingProcessingBlock, self).__init__()
self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False, bias=False)
self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=True, bias=False)
self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, bn=False, bias=False)
self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, bn=True, bias=False)
def forward(self, x):
x = self.bnconv1(x)
return self.bnconv2(x)
@ -45,7 +45,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
convs = []
current_filters = filters_init
for i in range(num_convs):
convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
current_filters += filter_growth
return nn.Sequential(*convs), current_filters

View File

@ -142,6 +142,45 @@ class PixelUnshuffle(nn.Module):
return x
# simply define a silu function
def silu(input):
'''
Applies the Sigmoid Linear Unit (SiLU) function element-wise:
SiLU(x) = x * sigmoid(x)
'''
return input * torch.sigmoid(input) # use torch.sigmoid to make sure that we created the most efficient implemetation based on builtin PyTorch functions
# create a class wrapper from PyTorch nn.Module, so
# the function now can be easily used in models
class SiLU(nn.Module):
'''
Applies the Sigmoid Linear Unit (SiLU) function element-wise:
SiLU(x) = x * sigmoid(x)
Shape:
- Input: (N, *) where * means, any number of additional
dimensions
- Output: (N, *), same shape as the input
References:
- Related paper:
https://arxiv.org/pdf/1606.08415.pdf
Examples:
>>> m = silu()
>>> input = torch.randn(2)
>>> output = m(input)
'''
def __init__(self):
'''
Init method.
'''
super().__init__() # init the base class
def forward(self, input):
'''
Forward pass of the function.
'''
return silu(input) # simply apply already implemented SiLU
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnRelu(nn.Module):
@ -176,6 +215,42 @@ class ConvBnRelu(nn.Module):
else:
return x
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnSilu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True):
super(ConvBnSilu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if bn:
self.bn = nn.BatchNorm2d(filters_out)
else:
self.bn = None
if silu:
self.silu = SiLU()
else:
self.silu = None
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv(x)
if self.bn:
x = self.bn(x)
if self.silu:
return self.silu(x)
else:
return x
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnLelu(nn.Module):