Add ConvBnSilu to replace ConvBnRelu
ReLU produced good performance gains over LeakyReLU, but GAN performance degraded significantly. Try SiLU as an alternative to see whether it is the leakiness we are looking for or the smooth activation curvature.
parent 9934e5d082
commit 10f7e49214
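For context on the question raised in the commit message, a minimal, illustrative sketch (not part of the commit) contrasting the three activations: ReLU zeroes negative inputs and has a kink at zero, LeakyReLU passes a scaled negative signal but keeps the kink, and SiLU is both non-zero for negative inputs and smooth everywhere.

    import torch
    from torch import nn

    x = torch.linspace(-3, 3, steps=7)
    relu = nn.ReLU()                          # zero for x < 0, kink at 0
    lrelu = nn.LeakyReLU(negative_slope=0.2)  # scaled negative pass-through, still kinked at 0
    silu = lambda t: t * torch.sigmoid(t)     # "leaky" for x < 0 and smooth everywhere

    for name, fn in [('relu', relu), ('lrelu', lrelu), ('silu', silu)]:
        print(name, fn(x))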
@@ -3,7 +3,7 @@ from torch import nn
 from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
-from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu
+from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu, ConvBnSilu
 from switched_conv_util import save_attention_to_image

@@ -32,8 +32,8 @@ class MultiConvBlock(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False, bias=False)
-        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=True, bias=False)
+        self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, bn=False, bias=False)
+        self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, bn=True, bias=False)
     def forward(self, x):
         x = self.bnconv1(x)
         return self.bnconv2(x)

@@ -45,7 +45,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
     convs = []
     current_filters = filters_init
     for i in range(num_convs):
-        convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
+        convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
         current_filters += filter_growth
     return nn.Sequential(*convs), current_filters

@@ -142,6 +142,45 @@ class PixelUnshuffle(nn.Module):
         return x


+# simply define a silu function
+def silu(input):
+    '''
+    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
+        SiLU(x) = x * sigmoid(x)
+    '''
+    return input * torch.sigmoid(input)  # use torch.sigmoid to ensure the most efficient implementation based on built-in PyTorch functions
+
+
+# create a class wrapper around PyTorch's nn.Module so
+# the function can be used easily inside models
+class SiLU(nn.Module):
+    '''
+    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
+        SiLU(x) = x * sigmoid(x)
+
+    Shape:
+        - Input: (N, *) where * means any number of additional dimensions
+        - Output: (N, *), same shape as the input
+
+    References:
+        - Related paper: https://arxiv.org/pdf/1606.08415.pdf
+
+    Examples:
+        >>> m = SiLU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    '''
+    def __init__(self):
+        '''
+        Init method.
+        '''
+        super().__init__()  # init the base class
+
+    def forward(self, input):
+        '''
+        Forward pass of the function.
+        '''
+        return silu(input)  # simply apply the already-implemented SiLU
+
+
 ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnRelu(nn.Module):
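Side note (not part of the commit): the hand-rolled silu above computes the same activation that later PyTorch releases (1.7+) ship as torch.nn.SiLU / torch.nn.functional.silu, so on a newer PyTorch this quick check should pass:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    assert torch.allclose(x * torch.sigmoid(x), F.silu(x))  # identical up to floating-point tolerance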
@@ -176,6 +215,42 @@ class ConvBnRelu(nn.Module):
         else:
             return x


+''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
+    kernel sizes. '''
+class ConvBnSilu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True):
+        super(ConvBnSilu, self).__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        if bn:
+            self.bn = nn.BatchNorm2d(filters_out)
+        else:
+            self.bn = None
+        if silu:
+            self.silu = SiLU()
+        else:
+            self.silu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn:
+            x = self.bn(x)
+        if self.silu:
+            return self.silu(x)
+        else:
+            return x
+
+
 ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnLelu(nn.Module):
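The new block is intended as a drop-in replacement for ConvBnRelu / ConvBnLelu; a minimal usage sketch based on the signatures shown in this diff (assuming the repository's models.archs.arch_util module is on the import path):

    import torch
    from models.archs.arch_util import ConvBnSilu

    # 3x3 conv with stride 2 halves the spatial resolution, followed by BatchNorm + SiLU.
    block = ConvBnSilu(64, 128, kernel_size=3, stride=2, bn=True, bias=False)
    out = block(torch.randn(1, 64, 32, 32))
    print(out.shape)  # torch.Size([1, 128, 16, 16])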