Add ConvBnSilu to replace ConvBnRelu
ReLU produced good performance gains over LeakyReLU, but GAN performance degraded significantly. Try SiLU as an alternative to determine whether the leakiness or the smooth activation curvature is what we are looking for.
commit 10f7e49214
parent 9934e5d082
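For context on the commit message: the three activations differ mainly in how they treat negative inputs. The standalone sketch below (not part of the commit, plain PyTorch) prints them side by side; ReLU zeroes negatives, LeakyReLU keeps a small linear slope (the "leakiness"), and SiLU is the smooth x * sigmoid(x) curve this change introduces.

import torch
import torch.nn.functional as F

# Probe negative and positive inputs to contrast the activations discussed above.
x = torch.tensor([-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0])
print("relu:      ", F.relu(x))                            # negatives clamped to 0
print("leaky_relu:", F.leaky_relu(x, negative_slope=0.1))  # negatives scaled by a small slope
print("silu:      ", x * torch.sigmoid(x))                 # smooth curve with a slight negative dip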
@@ -3,7 +3,7 @@ from torch import nn
 from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
-from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu
+from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu, ConvBnSilu
 from switched_conv_util import save_attention_to_image
 
 
@@ -32,8 +32,8 @@ class MultiConvBlock(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False, bias=False)
-        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=True, bias=False)
+        self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, bn=False, bias=False)
+        self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, bn=True, bias=False)
     def forward(self, x):
         x = self.bnconv1(x)
         return self.bnconv2(x)
@@ -45,7 +45,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
     convs = []
     current_filters = filters_init
     for i in range(num_convs):
-        convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
+        convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
         current_filters += filter_growth
     return nn.Sequential(*convs), current_filters
 
@@ -142,6 +142,45 @@ class PixelUnshuffle(nn.Module):
         return x
 
 
+# simply define a silu function
+def silu(input):
+    '''
+    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
+        SiLU(x) = x * sigmoid(x)
+    '''
+    return input * torch.sigmoid(input)  # use torch.sigmoid to make sure we use the most efficient implementation based on builtin PyTorch functions
+
+# create a class wrapper from PyTorch nn.Module, so
+# the function now can be easily used in models
+class SiLU(nn.Module):
+    '''
+    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
+        SiLU(x) = x * sigmoid(x)
+    Shape:
+        - Input: (N, *) where * means any number of additional
+          dimensions
+        - Output: (N, *), same shape as the input
+    References:
+        - Related paper:
+          https://arxiv.org/pdf/1606.08415.pdf
+    Examples:
+        >>> m = SiLU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    '''
+    def __init__(self):
+        '''
+        Init method.
+        '''
+        super().__init__()  # init the base class
+
+    def forward(self, input):
+        '''
+        Forward pass of the function.
+        '''
+        return silu(input)  # simply apply the already implemented SiLU
+
+
 ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnRelu(nn.Module):
@@ -176,6 +215,42 @@ class ConvBnRelu(nn.Module):
         else:
             return x
 
 
+''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
+    kernel sizes. '''
+class ConvBnSilu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True):
+        super(ConvBnSilu, self).__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        if bn:
+            self.bn = nn.BatchNorm2d(filters_out)
+        else:
+            self.bn = None
+        if silu:
+            self.silu = SiLU()
+        else:
+            self.silu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn:
+            x = self.bn(x)
+        if self.silu:
+            return self.silu(x)
+        else:
+            return x
+
+
 ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnLelu(nn.Module):
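A quick way to exercise the new block outside the full model: the snippet below constructs ConvBnSilu with the same arguments HalvingProcessingBlock passes to its first convolution and checks that stride=2 halves the spatial resolution. This is a sketch, not part of the commit; it assumes it is run from the repository root so the models.archs.arch_util import resolves. Recent PyTorch releases also ship torch.nn.SiLU, which behaves the same as the hand-rolled SiLU class added here.

import torch
from models.archs.arch_util import ConvBnSilu  # import path taken from the diff above

# Mirrors HalvingProcessingBlock: stride=2 downsamples, bn disabled on the first conv.
block = ConvBnSilu(64, 128, kernel_size=3, stride=2, bn=False, bias=False)
x = torch.randn(1, 64, 32, 32)
y = block(x)
print(y.shape)  # expected: torch.Size([1, 128, 16, 16])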