Transfer learning for styleSR

This is a concept from "Lifelong Learning GAN", although I'm skeptical of its novelty:
basically you learn a per-weight scale and shift over the frozen weights of a pretrained
GAN's generator and discriminator to adapt it to new modalities, e.g. faces->birds or
whatever. There are some interesting applications of this that I would like to try out.
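Concretely, the mechanism is just this (a minimal sketch mirroring the TransferLinear primitive added below; the trainer is assumed to skip parameters tagged DO_NOT_TRAIN, and freeze_all_but_transfer_params is an illustrative helper name):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleShiftLinear(nn.Module):
    # The pretrained weight stays frozen; adaptation trains a per-weight scale and shift.
    def __init__(self, in_features, out_features, transfer_mode=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.transfer_mode = transfer_mode
        if transfer_mode:
            # Identity at init, so the adapted network starts exactly at the pretrained one.
            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def forward(self, x):
        weight = self.weight
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift
        return F.linear(x, weight, self.bias)

def freeze_all_but_transfer_params(model):
    # Everything lacking the FOR_TRANSFER_LEARNING tag gets DO_NOT_TRAIN, so only
    # the scale/shift tensors remain trainable.
    for p in model.parameters():
        if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
            p.DO_NOT_TRAIN = True
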
James Betker 2021-01-04 20:10:48 -07:00
parent 2c65b6b28e
commit ade2732c82
6 changed files with 289 additions and 93 deletions

.idea/other.xml (new file)

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PySciProjectComponent">
<option name="PY_MATPLOTLIB_IN_TOOLWINDOW" value="false" />
</component>
</project>

(styled_sr discriminator file)

@@ -1,4 +1,6 @@
# Heavily based on the lucidrains stylegan2 discriminator implementation.
import math
import os
from functools import partial
from math import log2
from random import random
@@ -6,31 +8,33 @@ from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import grad as torch_grad
import trainer.losses as L
from vector_quantize_pytorch import VectorQuantize
from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
super().__init__()
self.filters = filters
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
) if downsample else None
def forward(self, x):
@@ -42,9 +46,10 @@ class DiscriminatorBlock(nn.Module):
return x
class StyleGan2Discriminator(nn.Module):
class StyleSrDiscriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
transfer_mode=False):
super().__init__()
num_layers = int(log2(image_size) - 1)
@@ -63,7 +68,7 @@ class StyleGan2Discriminator(nn.Module):
num_layer = ind + 1
is_not_last = ind != (len(chan_in_out) - 1)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
blocks.append(block)
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
@@ -84,17 +89,23 @@ class StyleGan2Discriminator(nn.Module):
chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
self.flatten = nn.Flatten()
if mlp:
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
leaky_relu(),
nn.Linear(100, 1))
TransferLinear(100, 1, transfer_mode=transfer_mode))
else:
self.to_logit = nn.Linear(latent_dim, 1)
self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
self._init_weights()
self.transfer_mode = transfer_mode
if transfer_mode:
for p in self.parameters():
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
p.DO_NOT_TRAIN = True
def forward(self, x):
b, *_ = x.shape
@@ -123,12 +134,12 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
if type(m) in {TransferConv2d, TransferLinear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
# Configures the network as partially pre-trained. This means:
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
# 2) The haed (linear layers) will also have their weights re-initialized
# 2) The head (linear layers) will also have their weights re-initialized
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
# These settings will be applied after the weights have been loaded (network_loaded())
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
@@ -150,7 +161,7 @@ class StyleGan2Discriminator(nn.Module):
reset_blocks.append(self.blocks[i])
for bl in reset_blocks:
for m in bl.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
if type(m) in {TransferConv2d, TransferLinear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for p in m.parameters(recurse=True):
p._NEW_BLOCK = True
@@ -237,13 +248,15 @@ class DiscAugmentor(nn.Module):
self.prob = prob
self.types = types
def forward(self, images, detach=False):
def forward(self, images, real_images=False):
if random() < self.prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=self.types)
if detach:
images = images.detach()
if real_images:
self.hq_aug = images.detach().clone()
else:
self.gen_aug = images.detach().clone()
# Save away for use elsewhere (e.g. unet loss)
self.aug_images = images
@@ -253,6 +266,11 @@ class DiscAugmentor(nn.Module):
def network_loaded(self):
self.D.network_loaded()
# Allows visualizing what the augmentor is up to.
def visual_dbg(self, step, path):
torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
if fp16:
@@ -294,12 +312,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake = D(fake_input)
fake = D(fake_input, real_images=False)
if self.for_gen:
return fake.mean()
else:
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
real = D(real_input, real_images=True)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
# Apply gradient penalty. TODO: migrate this elsewhere.
@@ -315,10 +333,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
@register_model
def register_styledsr_discriminator(opt_net, opt):
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
quantize=opt_get(opt_net, ['quantize'], False),
mlp=opt_get(opt_net, ['mlp_head'], True))
disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
quantize=opt_get(opt_net, ['quantize'], False),
mlp=opt_get(opt_net, ['mlp_head'], True),
transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
)
if 'use_partial_pretrained' in opt_net.keys():
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
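Given the register function above, the options block for a transfer-learning discriminator would look roughly like this (a sketch as a Python dict; the keys are the ones read above, the values are only illustrative):

opt_net = {
    'image_size': 128,
    'in_nc': 3,
    'attn_layers': [],
    'mlp_head': True,
    'transfer_mode': True,  # enables the scale/shift transfer parameters
    'augmentation_types': ['color', 'translation', 'cutout'],  # DiffAugment policies
    'augmentation_probability': 0.5,
}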

(styled_sr generator file)

@@ -1,29 +1,37 @@
from math import log2
from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.RRDBNet_arch import RRDB
from models.arch_util import ConvGnLelu, default_init_weights
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
from models.arch_util import kaiming_init
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
def rrdb_init_weights(module, scale=1):
for m in module.modules():
if isinstance(m, TransferConv2d):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
elif isinstance(m, TransferLinear):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
class EncoderRRDB(nn.Module):
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1):
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
super(EncoderRRDB, self).__init__()
for i in range(5):
out_channels = output_channels if i == 4 else growth_channels
self.add_module(
f'conv{i+1}',
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
1, 1))
TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for i in range(5):
default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
@@ -35,18 +43,18 @@ class EncoderRRDB(nn.Module):
class StyledSrEncoder(nn.Module):
def __init__(self, fea_out=256, initial_stride=1):
def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
super().__init__()
# Currently assumes fea_out=256.
self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
self.rrdbs = nn.ModuleList([
EncoderRRDB(32),
EncoderRRDB(64),
EncoderRRDB(96),
EncoderRRDB(128),
EncoderRRDB(160),
EncoderRRDB(192),
EncoderRRDB(224)])
EncoderRRDB(32, transfer_mode=transfer_mode),
EncoderRRDB(64, transfer_mode=transfer_mode),
EncoderRRDB(96, transfer_mode=transfer_mode),
EncoderRRDB(128, transfer_mode=transfer_mode),
EncoderRRDB(160, transfer_mode=transfer_mode),
EncoderRRDB(192, transfer_mode=transfer_mode),
EncoderRRDB(224, transfer_mode=transfer_mode)])
def forward(self, x):
fea = self.initial_conv(x)
@@ -56,13 +64,14 @@ class StyledSrEncoder(nn.Module):
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
super().__init__()
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
self.image_size = image_size
self.scale = 2 ** upsample_levels
self.latent_dim = latent_dim
self.num_layers = total_levels
self.transfer_mode = transfer_mode
filters = [
512, # 4x4
512, # 8x8
@@ -75,7 +84,8 @@ class Generator(nn.Module):
8, # 1024x1024
]
self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
# I'm making a guess here that the encoder does not need transfer learning, hence fixed transfer_mode=False. This should be vetted.
self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)
in_out_pairs = list(zip(filters[:-1], filters[1:]))
self.blocks = nn.ModuleList([])
@@ -88,13 +98,18 @@ class Generator(nn.Module):
in_chan,
out_chan,
upsample=not_first,
upsample_rgb=not_last
upsample_rgb=not_last,
transfer_learning_mode=transfer_mode
)
self.blocks.append(block)
def forward(self, lr, styles):
b, c, h, w = lr.shape
x = self.encoder(lr)
if self.transfer_mode:
with torch.no_grad():
x = self.encoder(lr)
else:
x = self.encoder(lr)
styles = styles.transpose(0, 1)
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
@@ -102,6 +117,7 @@ class Generator(nn.Module):
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
else:
rgb = lr
for style, block in zip(styles, self.blocks):
x, rgb = checkpoint(block, x, rgb, style, input_noise)
@@ -109,16 +125,23 @@
class StyledSrGenerator(nn.Module):
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
super().__init__()
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
# Assume the vectorizer doesn't need transfer_mode=True. Re-evaluate this later.
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
self.mixed_prob = .9
self._init_weights()
self.transfer_mode = transfer_mode
if transfer_mode:
for p in self.parameters():
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
p.DO_NOT_TRAIN = True
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
@@ -132,7 +155,11 @@ class StyledSrGenerator(nn.Module):
# Synthesize style latents from noise.
style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
w = self.vectorizer(style)
if self.transfer_mode:
with torch.no_grad():
w = self.vectorizer(style)
else:
w = self.vectorizer(style)
# Randomly distribute styles across layers
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
@@ -162,4 +189,6 @@ if __name__ == '__main__':
@register_model
def register_styled_sr(opt_net, opt):
return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
return StyledSrGenerator(128,
initial_stride=opt_get(opt_net, ['initial_stride'], 1),
transfer_mode=opt_get(opt_net, ['transfer_mode'], False))

models/styled_sr/stylegan2_base.py

@@ -6,8 +6,12 @@ import torch
import torch.nn.functional as F
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn
from torch import nn, Tensor
from torch.autograd import grad as torch_grad
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from models.styled_sr.transfer_primitives import TransferLinear
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
@@ -196,10 +200,9 @@ def slerp(val, low, high):
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
# stylegan2 classes
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
if bias:
@@ -207,17 +210,28 @@ class EqualLinear(nn.Module):
self.lr_mul = lr_mul
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def forward(self, input):
return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
if self.transfer_mode:
weight = self.weight * self.transfer_scale + self.transfer_shift
else:
weight = self.weight
return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)
class StyleVectorizer(nn.Module):
def __init__(self, emb, depth, lr_mul=0.1):
def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
super().__init__()
layers = []
for i in range(depth):
layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])
layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])
self.net = nn.Sequential(*layers)
@@ -227,13 +241,13 @@ class StyleVectorizer(nn.Module):
class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
super().__init__()
self.input_channel = input_channel
self.to_style = nn.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
@@ -307,25 +321,11 @@ class EqualLR:
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class Conv2DMod(nn.Module):
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
super().__init__()
self.filters = out_chan
self.demod = demod
@@ -334,6 +334,12 @@ class Conv2DMod(nn.Module):
self.dilation = dilation
self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def _get_same_padding(self, size, kernel, dilation, stride):
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
@@ -341,8 +347,13 @@ class Conv2DMod(nn.Module):
def forward(self, x, y):
b, c, h, w = x.shape
if self.transfer_mode:
weight = self.weight * self.transfer_scale + self.transfer_shift
else:
weight = self.weight
w1 = y[:, None, :, None, None]
w2 = self.weight[None, :, :, :, :]
w2 = weight[None, :, :, :, :]
weights = w2 * (w1 + 1)
if self.demod:
@@ -362,34 +373,28 @@ class Conv2DMod(nn.Module):
class GeneratorBlock(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
transfer_learning_mode=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
self.structure_input = structure_input
if self.structure_input:
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2
self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)
self.to_style1 = nn.Linear(latent_dim, input_channels)
self.to_noise1 = nn.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)
def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
self.transfer_learning_mode = transfer_learning_mode
def forward(self, x, prev_rgb, istyle, inoise):
if exists(self.upsample):
x = self.upsample(x)
if self.structure_input:
s = self.structure_conv(structure_input)
x = torch.cat([x, s], dim=1)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

models/styled_sr/transfer_primitives.py (new file)

@@ -0,0 +1,136 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _ntuple
_pair = _ntuple(2)
class TransferConv2d(_ConvNd):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
transfer_mode: bool = False
):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode)
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def _conv_forward(self, input, weight):
if self.transfer_mode:
weight = weight * self.transfer_scale + self.transfer_shift
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight)
class TransferLinear(nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor) -> Tensor:
if self.transfer_mode:
weight = self.weight * self.transfer_scale + self.transfer_shift
else:
weight = self.weight
return F.linear(input, weight, self.bias)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
class TransferConvGnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
super().__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias, transfer_mode=transfer_mode)
if norm:
self.gn = nn.GroupNorm(num_groups, filters_out)
else:
self.gn = None
if activation:
self.lelu = nn.LeakyReLU(negative_slope=.2)
else:
self.lelu = None
# Init params.
for m in self.modules():
if isinstance(m, TransferConv2d):
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
nonlinearity='leaky_relu' if self.lelu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv(x)
if self.gn:
x = self.gn(x)
if self.lelu:
return self.lelu(x)
else:
return x
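For completeness, the adaptation flow these primitives imply, as a sketch (the checkpoint path is hypothetical, and strict=False is needed because pretrained checkpoints predate the transfer_* parameters):

# Rebuild the model with transfer_mode=True, load the pretrained weights, then
# optimize only the tagged scale/shift parameters.
net = TransferConvGnLelu(3, 64, kernel_size=3, transfer_mode=True)
pretrained_state = torch.load('pretrained.pth')  # hypothetical checkpoint
net.load_state_dict(pretrained_state, strict=False)  # transfer_* keys are new
optimizer = torch.optim.Adam(
    [p for p in net.parameters() if hasattr(p, 'FOR_TRANSFER_LEARNING')], lr=1e-4)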
