DL-Art-School/codes/models/styled_sr/transfer_primitives.py
James Betker ade2732c82 Transfer learning for styleSR
This is a concept from "Lifelong Learning GAN", although I'm skeptical of its novelty -
basically you scale and shift the weights for the generator and discriminator of a pretrained
GAN to "shift" into new modalities, e.g. faces->birds or whatever. There are some interesting
applications of this that I would like to try out.
2021-01-04 20:10:48 -07:00

136 lines
5.0 KiB
Python

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _ntuple
_pair = _ntuple(2)
class TransferConv2d(_ConvNd):
    """2-D convolution whose kernel can be scaled and shifted for transfer learning.

    With ``transfer_mode=False`` this performs an ordinary 2-D convolution.
    With ``transfer_mode=True`` two extra parameters are registered,
    ``transfer_scale`` (initialised to ones) and ``transfer_shift``
    (initialised to zeros), and the kernel used in the forward pass is
    ``weight * transfer_scale + transfer_shift``. At initialisation the
    modulation is therefore the identity.

    Both extra parameters carry a ``FOR_TRANSFER_LEARNING = True`` attribute.
    NOTE(review): presumably the training code uses this flag to pick which
    parameters to optimise - confirm against the caller.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size, an int or a pair of ints.
        stride: Convolution stride, an int or a pair of ints.
        padding: Padding applied to the input, an int or a pair of ints.
        dilation: Kernel dilation, an int or a pair of ints.
        groups: Number of groups the channels are split into.
        bias: Whether a learnable bias is added to the output.
        padding_mode: ``'zeros'`` or any mode accepted by ``F.pad``.
        transfer_mode: Whether to register and apply the scale/shift
            modulation of the kernel.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        transfer_mode: bool = False
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)
        self.transfer_mode = transfer_mode
        if transfer_mode:
            # The kernel is (out_channels, in_channels // groups, kH, kW), so
            # the per-filter modulation must use the per-group input channel
            # count; otherwise it cannot broadcast against the kernel when
            # groups > 1.
            modulation_shape = (out_channels, in_channels // groups, 1, 1)
            self.transfer_scale = nn.Parameter(torch.ones(modulation_shape))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(modulation_shape))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with ``weight``, modulating it first in transfer mode."""
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, so the
            # convolution itself runs without padding.
            padded = F.pad(input, self._reversed_padding_repeated_twice,
                           mode=self.padding_mode)
            return F.conv2d(padded, weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the (optionally modulated) convolution to ``input``."""
        return self._conv_forward(input, self.weight)
class TransferLinear(nn.Module):
    """Linear layer whose weight can be scaled and shifted for transfer learning.

    With ``transfer_mode=False`` this computes ``F.linear(input, weight, bias)``.
    With ``transfer_mode=True`` two extra parameters shaped like the weight are
    registered, ``transfer_scale`` (ones) and ``transfer_shift`` (zeros), and
    the forward pass uses ``weight * transfer_scale + transfer_shift`` as the
    weight. At initialisation the modulation is therefore the identity.

    Both extra parameters carry a ``FOR_TRANSFER_LEARNING = True`` attribute.
    NOTE(review): presumably the training code uses this flag to pick which
    parameters to optimise - confirm against the caller.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether a learnable bias is added to the output.
        transfer_mode: Whether to register and apply the scale/shift
            modulation of the weight.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty follows the default dtype, matching the torch.ones /
        # torch.zeros transfer parameters created below.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` (Kaiming uniform) and ``bias`` (uniform in +-1/sqrt(fan_in)).

        A layer with ``in_features == 0`` has a fan-in of zero; in that case
        the weight has no elements to initialise and the bias is set to zero
        rather than dividing by zero.
        """
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in > 0:
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the (optionally modulated) affine map to ``input``."""
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

    def extra_repr(self) -> str:
        """Return the summary shown inside this module's ``repr``."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
class TransferConvGnLelu(nn.Module):
    """TransferConv2d followed by optional GroupNorm and optional LeakyReLU(0.2).

    The convolution padding is looked up from the kernel size (1, 3, 5 or 7
    map to 0, 1, 2, 3 respectively). The conv weight is Kaiming-normal
    initialised (fan-out mode) and then multiplied by ``weight_init_factor``;
    the conv bias, when present, is zeroed. The norm layer, when present, is
    reset to weight 1 and bias 0.

    Args:
        filters_in: Number of input channels.
        filters_out: Number of output channels.
        kernel_size: One of 1, 3, 5 or 7.
        stride: Convolution stride.
        activation: Whether to append the LeakyReLU.
        norm: Whether to append the GroupNorm.
        bias: Whether the convolution has a bias.
        num_groups: Group count passed to ``nn.GroupNorm``.
        weight_init_factor: Multiplier applied to the conv weight after init.
        transfer_mode: Forwarded to ``TransferConv2d``.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
        super().__init__()
        padding_for_kernel = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_for_kernel
        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride,
                                   padding_for_kernel[kernel_size], bias=bias,
                                   transfer_mode=transfer_mode)
        self.gn = nn.GroupNorm(num_groups, filters_out) if norm else None
        self.lelu = nn.LeakyReLU(negative_slope=.2) if activation else None

        # Init params: convolution first, then the norm layer.
        gain_kind = 'leaky_relu' if self.lelu is not None else 'linear'
        nn.init.kaiming_normal_(self.conv.weight, a=.1, mode='fan_out',
                                nonlinearity=gain_kind)
        self.conv.weight.data.mul_(weight_init_factor)
        if self.conv.bias is not None:
            self.conv.bias.data.zero_()
        if self.gn is not None:
            nn.init.constant_(self.gn.weight, 1)
            nn.init.constant_(self.gn.bias, 0)

    def forward(self, x):
        """Run conv, then the norm and activation stages that are enabled."""
        out = self.conv(x)
        if self.gn is not None:
            out = self.gn(out)
        if self.lelu is not None:
            out = self.lelu(out)
        return out