diff --git a/codes/models/archs/SRG1_arch.py b/codes/models/archs/SRG1_arch.py
index 1224110d..5350cfbc 100644
--- a/codes/models/archs/SRG1_arch.py
+++ b/codes/models/archs/SRG1_arch.py
@@ -3,7 +3,7 @@ from torch import nn
 from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
-from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu
+from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu, ConvBnSilu
 from switched_conv_util import save_attention_to_image


@@ -32,8 +32,8 @@ class MultiConvBlock(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False, bias=False)
-        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=True, bias=False)
+        self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, bn=False, bias=False)
+        self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, bn=True, bias=False)
     def forward(self, x):
         x = self.bnconv1(x)
         return self.bnconv2(x)
@@ -45,7 +45,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
     convs = []
     current_filters = filters_init
     for i in range(num_convs):
-        convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
+        convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
         current_filters += filter_growth
     return nn.Sequential(*convs), current_filters

diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index f4ed27b4..5c7efbd0 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -142,6 +142,45 @@ class PixelUnshuffle(nn.Module):
         return x


+# simply define a silu function
+def silu(input):
+    '''
+    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
+        SiLU(x) = x * sigmoid(x)
+    '''
+    return input * torch.sigmoid(input)  # use torch.sigmoid to make sure that we get the most efficient implementation based on built-in PyTorch functions
+
+
+# create a class wrapper from PyTorch nn.Module, so
+# the function can now be easily used in models
+class SiLU(nn.Module):
+    '''
+    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
+        SiLU(x) = x * sigmoid(x)
+    Shape:
+        - Input: (N, *) where * means any number of additional
+          dimensions
+        - Output: (N, *), same shape as the input
+    References:
+        - Related paper:
+          https://arxiv.org/pdf/1606.08415.pdf
+    Examples:
+        >>> m = SiLU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    '''
+    def __init__(self):
+        '''
+        Init method.
+        '''
+        super().__init__()  # init the base class
+
+    def forward(self, input):
+        '''
+        Forward pass of the function.
+        '''
+        return silu(input)  # simply apply the already implemented SiLU
+
+
 ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnRelu(nn.Module):
@@ -176,6 +215,42 @@ class ConvBnRelu(nn.Module):
         else:
             return x

+
+''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
+    kernel sizes. '''
+class ConvBnSilu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True):
+        super(ConvBnSilu, self).__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        if bn:
+            self.bn = nn.BatchNorm2d(filters_out)
+        else:
+            self.bn = None
+        if silu:
+            self.silu = SiLU()
+        else:
+            self.silu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn:
+            x = self.bn(x)
+        if self.silu:
+            return self.silu(x)
+        else:
+            return x
+
+
 ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnLelu(nn.Module):
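
Not part of the patch: a minimal smoke-test sketch for the new SiLU and ConvBnSilu modules, assuming the repository's codes/ directory is on PYTHONPATH so that models.archs.arch_util imports as above; tensor shapes and hyperparameters are arbitrary.

    import torch
    from models.archs.arch_util import SiLU, ConvBnSilu

    x = torch.randn(4, 32, 16, 16)

    # The SiLU module should match its functional definition: y = x * sigmoid(x).
    assert torch.allclose(SiLU()(x), x * torch.sigmoid(x))

    # ConvBnSilu mirrors ConvBnRelu/ConvBnLelu: Conv2d -> optional BatchNorm2d -> optional SiLU.
    # With kernel_size=3 (auto-padded to 1) and stride=2, spatial dims halve: 16 -> 8.
    block = ConvBnSilu(32, 64, kernel_size=3, stride=2, bn=True, bias=False)
    print(block(x).shape)  # torch.Size([4, 64, 8, 8])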