forked from mrq/DL-Art-School
136 lines
5.0 KiB
Python
136 lines
5.0 KiB
Python
|
import math
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from torch import Tensor
|
||
|
from torch.nn import Parameter, init
|
||
|
from torch.nn.modules.conv import _ConvNd
|
||
|
from torch.nn.modules.utils import _ntuple
|
||
|
|
||
|
# Helper that turns an int (or an iterable) into a 2-tuple, e.g. 3 -> (3, 3);
# used to normalize Conv2d-style kernel_size/stride/padding/dilation arguments.
_pair = _ntuple(n=2)
|
||
|
|
||
|
class TransferConv2d(_ConvNd):
    """2D convolution with optional learnable scale/shift applied to its kernel.

    Arguments mirror ``torch.nn.Conv2d``. When ``transfer_mode`` is True, two
    extra parameters are created and the kernel actually used in ``forward`` is
    ``weight * transfer_scale + transfer_shift``. The scale starts at ones and
    the shift at zeros, so a freshly built layer computes the same result as a
    plain convolution. Both extra parameters are tagged with
    ``FOR_TRANSFER_LEARNING = True`` (presumably consumed by parameter-selection
    code elsewhere -- not visible here).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        transfer_mode: bool = False,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        # String padding ('same' / 'valid') must pass through untouched, as in
        # nn.Conv2d; _pair('same') would split the string into characters.
        padding = padding if isinstance(padding, str) else _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        self.transfer_mode = transfer_mode
        if transfer_mode:
            # The conv weight is (out_channels, in_channels // groups, kH, kW);
            # the trailing 1x1 broadcasts over the spatial kernel. Using the full
            # in_channels here would fail to broadcast whenever groups > 1.
            per_channel = (out_channels, in_channels // groups, 1, 1)
            self.transfer_scale = nn.Parameter(torch.ones(per_channel))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(per_channel))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with ``weight``, modulated by scale/shift in transfer mode."""
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift
        if self.padding_mode != 'zeros':
            # Non-zero padding modes pad explicitly first, then convolve unpadded.
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight)
|
||
|
|
||
|
|
||
|
class TransferLinear(nn.Module):
    """Linear layer with optional learnable elementwise scale/shift on its weight.

    Counterpart of ``torch.nn.Linear``. When ``transfer_mode`` is True the
    weight used in ``forward`` is ``weight * transfer_scale + transfer_shift``.
    The scale starts at ones and the shift at zeros, so the initial output
    equals a plain linear layer. Both extra parameters are tagged with
    ``FOR_TRANSFER_LEARNING = True``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if False, no additive bias is learned.
        transfer_mode: create the transfer scale/shift parameters.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Uninitialized storage; values are filled in by reset_parameters().
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def reset_parameters(self) -> None:
        """Initialize weight (Kaiming-uniform) and bias like ``torch.nn.Linear``.

        The transfer scale/shift parameters are not touched.
        """
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # in_features == 0 gives fan_in == 0; avoid dividing by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
|
||
|
|
||
|
|
||
|
class TransferConvGnLelu(nn.Module):
    """TransferConv2d followed by optional GroupNorm and optional LeakyReLU(0.2).

    Args:
        filters_in: input channel count.
        filters_out: output channel count (also the GroupNorm channel count).
        kernel_size: positive odd kernel size; padding is ``kernel_size // 2``,
            so spatial size is preserved when ``stride == 1``.
        stride: convolution stride.
        activation: append a LeakyReLU with negative slope 0.2.
        norm: append ``nn.GroupNorm(num_groups, filters_out)``.
        bias: give the convolution a bias (initialized to zero).
        num_groups: number of GroupNorm groups.
        weight_init_factor: multiplier applied to the Kaiming-initialized conv weight.
        transfer_mode: forwarded to TransferConv2d.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
        super().__init__()
        # kernel_size // 2 reproduces the former {1: 0, 3: 1, 5: 2, 7: 3} table
        # and extends it to every odd size. An explicit raise (rather than
        # assert) keeps the check active under `python -O`.
        if kernel_size < 1 or kernel_size % 2 != 1:
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}')
        padding = kernel_size // 2
        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding, bias=bias, transfer_mode=transfer_mode)
        self.gn = nn.GroupNorm(num_groups, filters_out) if norm else None
        self.lelu = nn.LeakyReLU(negative_slope=.2) if activation else None

        # Init params.
        for m in self.modules():
            if isinstance(m, TransferConv2d):
                # NOTE(review): a=.1 does not match the LeakyReLU slope of .2 set
                # above; kept unchanged to preserve the existing initialization.
                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                        nonlinearity='leaky_relu' if self.lelu is not None else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.gn is not None:
            x = self.gn(x)
        if self.lelu is not None:
            x = self.lelu(x)
        return x
|