diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index f79d4532..f4ed27b4 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -141,3 +141,72 @@ class PixelUnshuffle(nn.Module): x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, f * (self.r ** 2), w // self.r, h // self.r) return x + +''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard + kernel sizes. ''' +class ConvBnRelu(nn.Module): + def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True): + super(ConvBnRelu, self).__init__() + padding_map = {1: 0, 3: 1, 5: 2, 7: 3} + assert kernel_size in padding_map.keys() + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + if bn: + self.bn = nn.BatchNorm2d(filters_out) + else: + self.bn = None + if relu: + self.relu = nn.ReLU() + else: + self.relu = None + + # Init params. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = self.bn(x) + if self.relu: + return self.relu(x) + else: + return x + +''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard + kernel sizes. ''' +class ConvBnLelu(nn.Module): + def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True): + super(ConvBnLelu, self).__init__() + padding_map = {1: 0, 3: 1, 5: 2, 7: 3} + assert kernel_size in padding_map.keys() + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + if bn: + self.bn = nn.BatchNorm2d(filters_out) + else: + self.bn = None + if lelu: + self.lelu = nn.LeakyReLU(negative_slope=.1) + else: + self.lelu = None + + # Init params. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out', + nonlinearity='leaky_relu' if self.lelu else 'linear') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = self.bn(x) + if self.lelu: + return self.lelu(x) + else: + return x \ No newline at end of file