Transfer learning for styleSR

This is a concept from "Lifelong Learning GAN", although I'm skeptical of it's novelty -
basically you scale and shift the weights for the generator and discriminator of a pretrained
GAN to "shift" into new modalities, e.g. faces->birds or whatever. There are some interesting
applications of this that I would like to try out.
James Betker 2021-01-04 20:10:48 -07:00
parent 2c65b6b28e
commit ade2732c82
6 changed files with 289 additions and 93 deletions
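In code, the trick amounts to reparameterizing every frozen pretrained weight tensor W as W * scale + shift, with scale initialized to ones and shift to zeros, so the network starts out exactly equal to the pretrained one and only the scale/shift tensors learn the new domain. Here is a minimal sketch of the idea for a plain linear layer (ScaleShiftLinear is a hypothetical name; the diff below implements the same thing as TransferLinear/TransferConv2d behind a transfer_mode flag):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleShiftLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Pretrained weight, frozen outright here for simplicity; the diff
        # below instead tags it DO_NOT_TRAIN and leaves the skipping to the trainer.
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        # The only trainable tensors: identity at init, so the layer initially
        # reproduces the pretrained mapping exactly.
        self.scale = nn.Parameter(torch.ones(out_features, in_features))
        self.shift = nn.Parameter(torch.zeros(out_features, in_features))

    def forward(self, x):
        return F.linear(x, self.weight * self.scale + self.shift)

Note the scale/shift pair is not a parameter-count savings (it has twice the parameters of the weight it wraps); the point is that the pretrained weights are never overwritten, so the shifted network stays anchored to the source domain.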

.idea/other.xml (new file)

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PySciProjectComponent">
+    <option name="PY_MATPLOTLIB_IN_TOOLWINDOW" value="false" />
+  </component>
+</project>

(StyleSR discriminator module)

@@ -1,4 +1,6 @@
 # Heavily based on the lucidrains stylegan2 discriminator implementation.
+import math
+import os
 from functools import partial
 from math import log2
 from random import random
@@ -6,31 +8,33 @@ from random import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision
 from torch.autograd import grad as torch_grad

 import trainer.losses as L
 from vector_quantize_pytorch import VectorQuantize
 from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
+from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get


 class DiscriminatorBlock(nn.Module):
-    def __init__(self, input_channels, filters, downsample=True):
+    def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
         super().__init__()
         self.filters = filters
-        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
+        self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)

         self.net = nn.Sequential(
-            nn.Conv2d(input_channels, filters, 3, padding=1),
+            TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
             leaky_relu(),
-            nn.Conv2d(filters, filters, 3, padding=1),
+            TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
             leaky_relu()
         )

         self.downsample = nn.Sequential(
             Blur(),
-            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
+            TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
         ) if downsample else None

     def forward(self, x):
@@ -42,9 +46,10 @@ class DiscriminatorBlock(nn.Module):
         return x


-class StyleGan2Discriminator(nn.Module):
+class StyleSrDiscriminator(nn.Module):
     def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
-                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
+                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
+                 transfer_mode=False):
         super().__init__()
         num_layers = int(log2(image_size) - 1)
@@ -63,7 +68,7 @@ class StyleGan2Discriminator(nn.Module):
             num_layer = ind + 1
             is_not_last = ind != (len(chan_in_out) - 1)

-            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
+            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
             blocks.append(block)

             attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
@@ -84,17 +89,23 @@ class StyleGan2Discriminator(nn.Module):
         chan_last = filters[-1]
         latent_dim = 2 * 2 * chan_last

-        self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
+        self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
         self.flatten = nn.Flatten()
         if mlp:
-            self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
+            self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
                                           leaky_relu(),
-                                          nn.Linear(100, 1))
+                                          TransferLinear(100, 1, transfer_mode=transfer_mode))
         else:
-            self.to_logit = nn.Linear(latent_dim, 1)
+            self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)

         self._init_weights()

+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            for p in self.parameters():
+                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
+                    p.DO_NOT_TRAIN = True
+
     def forward(self, x):
         b, *_ = x.shape
@@ -123,12 +134,12 @@ class StyleGan2Discriminator(nn.Module):
     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear}:
+            if type(m) in {TransferConv2d, TransferLinear}:
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

     # Configures the network as partially pre-trained. This means:
     # 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
-    # 2) The haed (linear layers) will also have their weights re-initialized
+    # 2) The head (linear layers) will also have their weights re-initialized
     # 3) All intermediate blocks will be frozen until step `frozen_until_step`
     # These settings will be applied after the weights have been loaded (network_loaded())
     def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
@@ -150,7 +161,7 @@ class StyleGan2Discriminator(nn.Module):
             reset_blocks.append(self.blocks[i])
         for bl in reset_blocks:
             for m in bl.modules():
-                if type(m) in {nn.Conv2d, nn.Linear}:
+                if type(m) in {TransferConv2d, TransferLinear}:
                     nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                     for p in m.parameters(recurse=True):
                         p._NEW_BLOCK = True
@@ -237,13 +248,15 @@ class DiscAugmentor(nn.Module):
         self.prob = prob
         self.types = types

-    def forward(self, images, detach=False):
+    def forward(self, images, real_images=False):
         if random() < self.prob:
             images = random_hflip(images, prob=0.5)
             images = DiffAugment(images, types=self.types)

-        if detach:
-            images = images.detach()
+        if real_images:
+            self.hq_aug = images.detach().clone()
+        else:
+            self.gen_aug = images.detach().clone()

         # Save away for use elsewhere (e.g. unet loss)
         self.aug_images = images
@@ -253,6 +266,11 @@ class DiscAugmentor(nn.Module):
     def network_loaded(self):
         self.D.network_loaded()

+    # Allows visualizing what the augmentor is up to.
+    def visual_dbg(self, step, path):
+        torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
+        torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))
+

 def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     if fp16:
@@ -294,12 +312,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
             real_input = real_input + torch.rand_like(real_input) * self.noise

         D = self.env['discriminators'][self.discriminator]
-        fake = D(fake_input)
+        fake = D(fake_input, real_images=False)
         if self.for_gen:
             return fake.mean()
         else:
             real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
-            real = D(real_input)
+            real = D(real_input, real_images=True)
             divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()

             # Apply gradient penalty. TODO: migrate this elsewhere.
@@ -315,10 +333,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
 @register_model
 def register_styledsr_discriminator(opt_net, opt):
     attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
-    disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
-                                  do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
-                                  quantize=opt_get(opt_net, ['quantize'], False),
-                                  mlp=opt_get(opt_net, ['mlp_head'], True))
+    disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
+                                do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
+                                quantize=opt_get(opt_net, ['quantize'], False),
+                                mlp=opt_get(opt_net, ['mlp_head'], True),
+                                transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
+                                )
     if 'use_partial_pretrained' in opt_net.keys():
         disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
     return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
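For reference, here are the option keys this registration path reads, collected into one hypothetical dict (illustrative values; the real trainer builds opt_net from its config file, and the augmentation types are assumed DiffAugment names):

opt_net = {
    'image_size': 128,
    'in_nc': 3,
    'attn_layers': [],
    'do_checkpointing': False,
    'quantize': False,
    'mlp_head': True,
    'transfer_mode': True,  # switches every TransferConv2d/TransferLinear into scale/shift mode
    'augmentation_types': ['translation', 'cutout'],
    'augmentation_probability': 0.5,
}
disc = register_styledsr_discriminator(opt_net, opt={})  # returns a DiscAugmentor wrapping the discriminator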

(StyleSR generator module)

@@ -1,29 +1,37 @@
+from math import log2
 from random import random

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from models.RRDBNet_arch import RRDB
-from models.arch_util import ConvGnLelu, default_init_weights
-from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
+from models.arch_util import kaiming_init
+from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
+from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get


+def rrdb_init_weights(module, scale=1):
+    for m in module.modules():
+        if isinstance(m, TransferConv2d):
+            kaiming_init(m, a=0, mode='fan_in', bias=0)
+            m.weight.data *= scale
+        elif isinstance(m, TransferLinear):
+            kaiming_init(m, a=0, mode='fan_in', bias=0)
+            m.weight.data *= scale
+
+
 class EncoderRRDB(nn.Module):
-    def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1):
+    def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
         super(EncoderRRDB, self).__init__()
         for i in range(5):
             out_channels = output_channels if i == 4 else growth_channels
             self.add_module(
                 f'conv{i+1}',
-                nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
-                          1, 1))
+                TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         for i in range(5):
-            default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
+            rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)

     def forward(self, x):
         x1 = self.lrelu(self.conv1(x))
@@ -35,18 +43,18 @@ class EncoderRRDB(nn.Module):

 class StyledSrEncoder(nn.Module):
-    def __init__(self, fea_out=256, initial_stride=1):
+    def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
         super().__init__()
         # Current assumes fea_out=256.
-        self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
+        self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
         self.rrdbs = nn.ModuleList([
-            EncoderRRDB(32),
-            EncoderRRDB(64),
-            EncoderRRDB(96),
-            EncoderRRDB(128),
-            EncoderRRDB(160),
-            EncoderRRDB(192),
-            EncoderRRDB(224)])
+            EncoderRRDB(32, transfer_mode=transfer_mode),
+            EncoderRRDB(64, transfer_mode=transfer_mode),
+            EncoderRRDB(96, transfer_mode=transfer_mode),
+            EncoderRRDB(128, transfer_mode=transfer_mode),
+            EncoderRRDB(160, transfer_mode=transfer_mode),
+            EncoderRRDB(192, transfer_mode=transfer_mode),
+            EncoderRRDB(224, transfer_mode=transfer_mode)])

     def forward(self, x):
         fea = self.initial_conv(x)
@@ -56,13 +64,14 @@ class StyledSrEncoder(nn.Module):

 class Generator(nn.Module):
-    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
+    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
         super().__init__()
         total_levels = upsample_levels + 1  # The first level handles the raw encoder output and doesn't upsample.
         self.image_size = image_size
         self.scale = 2 ** upsample_levels
         self.latent_dim = latent_dim
         self.num_layers = total_levels
+        self.transfer_mode = transfer_mode
         filters = [
             512,  # 4x4
             512,  # 8x8
@@ -75,7 +84,8 @@ class Generator(nn.Module):
             8,    # 1024x1024
         ]

-        self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
+        # I'm making a guess here that the encoder does not need transfer learning, hence fixed transfer_mode=False. This should be vetted.
+        self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)

         in_out_pairs = list(zip(filters[:-1], filters[1:]))
         self.blocks = nn.ModuleList([])
@@ -88,13 +98,18 @@ class Generator(nn.Module):
                 in_chan,
                 out_chan,
                 upsample=not_first,
-                upsample_rgb=not_last
+                upsample_rgb=not_last,
+                transfer_learning_mode=transfer_mode
             )
             self.blocks.append(block)

     def forward(self, lr, styles):
         b, c, h, w = lr.shape
-        x = self.encoder(lr)
+        if self.transfer_mode:
+            with torch.no_grad():
+                x = self.encoder(lr)
+        else:
+            x = self.encoder(lr)

         styles = styles.transpose(0, 1)
         input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
@@ -102,6 +117,7 @@ class Generator(nn.Module):
             rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
         else:
             rgb = lr
+
         for style, block in zip(styles, self.blocks):
             x, rgb = checkpoint(block, x, rgb, style, input_noise)
@@ -109,16 +125,23 @@ class Generator(nn.Module):

 class StyledSrGenerator(nn.Module):
-    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
+    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
         super().__init__()
-        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
-        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
+        # Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
+        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
+        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
         self.mixed_prob = .9
         self._init_weights()

+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            for p in self.parameters():
+                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
+                    p.DO_NOT_TRAIN = True
+
     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
+            if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

         for block in self.gen.blocks:
@@ -132,7 +155,11 @@ class StyledSrGenerator(nn.Module):
         # Synthesize style latents from noise.
         style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
-        w = self.vectorizer(style)
+        if self.transfer_mode:
+            with torch.no_grad():
+                w = self.vectorizer(style)
+        else:
+            w = self.vectorizer(style)

         # Randomly distribute styles across layers
         w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
@@ -162,4 +189,6 @@ if __name__ == '__main__':

 @register_model
 def register_styled_sr(opt_net, opt):
-    return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
+    return StyledSrGenerator(128,
+                             initial_stride=opt_get(opt_net, ['initial_stride'], 1),
+                             transfer_mode=opt_get(opt_net, ['transfer_mode'], False))

models/styled_sr/stylegan2_base.py

@@ -6,8 +6,12 @@ import torch
 import torch.nn.functional as F
 from kornia.filters import filter2D
 from linear_attention_transformer import ImageLinearAttention
-from torch import nn
+from torch import nn, Tensor
 from torch.autograd import grad as torch_grad
+from torch.nn import Parameter, init
+from torch.nn.modules.conv import _ConvNd
+
+from models.styled_sr.transfer_primitives import TransferLinear

 assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
@@ -196,10 +200,9 @@ def slerp(val, low, high):
     res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
     return res

-# stylegan2 classes

 class EqualLinear(nn.Module):
-    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
+    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
         if bias:
@@ -207,17 +210,28 @@ class EqualLinear(nn.Module):

         self.lr_mul = lr_mul

+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
+
     def forward(self, input):
-        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
+        if self.transfer_mode:
+            weight = self.weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = self.weight
+        return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)


 class StyleVectorizer(nn.Module):
-    def __init__(self, emb, depth, lr_mul=0.1):
+    def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
         super().__init__()

         layers = []
         for i in range(depth):
-            layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])
+            layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])

         self.net = nn.Sequential(*layers)
@@ -227,13 +241,13 @@ class StyleVectorizer(nn.Module):

 class RGBBlock(nn.Module):
-    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
+    def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
         super().__init__()
         self.input_channel = input_channel
         self.to_style = nn.Linear(latent_dim, input_channel)

         out_filters = 3 if not rgba else 4
-        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
+        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)

         self.upsample = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
@@ -307,25 +321,11 @@ class EqualLR:
 def equal_lr(module, name='weight'):
     EqualLR.apply(module, name)
     return module


-class EqualConv2d(nn.Module):
-    def __init__(self, *args, **kwargs):
-        super().__init__()
-
-        conv = nn.Conv2d(*args, **kwargs)
-        conv.weight.data.normal_()
-        conv.bias.data.zero_()
-        self.conv = equal_lr(conv)
-
-    def forward(self, input):
-        return self.conv(input)
-
-
 class Conv2DMod(nn.Module):
-    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
+    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
         super().__init__()
         self.filters = out_chan
         self.demod = demod
@@ -334,6 +334,12 @@ class Conv2DMod(nn.Module):
         self.dilation = dilation
         self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
         nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True

     def _get_same_padding(self, size, kernel, dilation, stride):
         return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
@@ -341,8 +347,13 @@ class Conv2DMod(nn.Module):
     def forward(self, x, y):
         b, c, h, w = x.shape

+        if self.transfer_mode:
+            weight = self.weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = self.weight
+
         w1 = y[:, None, :, None, None]
-        w2 = self.weight[None, :, :, :, :]
+        w2 = weight[None, :, :, :, :]
         weights = w2 * (w1 + 1)

         if self.demod:
@@ -362,34 +373,28 @@ class Conv2DMod(nn.Module):

 class GeneratorBlock(nn.Module):
-    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
+    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
+                 transfer_learning_mode=False):
         super().__init__()
         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

-        self.structure_input = structure_input
-        if self.structure_input:
-            self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
-            input_channels = input_channels * 2
-
-        self.to_style1 = nn.Linear(latent_dim, input_channels)
-        self.to_noise1 = nn.Linear(1, filters)
-        self.conv1 = Conv2DMod(input_channels, filters, 3)
+        self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
+        self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
+        self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)

-        self.to_style2 = nn.Linear(latent_dim, filters)
-        self.to_noise2 = nn.Linear(1, filters)
-        self.conv2 = Conv2DMod(filters, filters, 3)
+        self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
+        self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
+        self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)

         self.activation = leaky_relu()
-        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
+        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)
+        self.transfer_learning_mode = transfer_learning_mode

-    def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
+    def forward(self, x, prev_rgb, istyle, inoise):
         if exists(self.upsample):
             x = self.upsample(x)

-        if self.structure_input:
-            s = self.structure_conv(structure_input)
-            x = torch.cat([x, s], dim=1)
-
         inoise = inoise[:, :x.shape[2], :x.shape[3], :]
         noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
         noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

models/styled_sr/transfer_primitives.py (new file)

@@ -0,0 +1,136 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter, init
+from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.utils import _ntuple
+
+_pair = _ntuple(2)
+
+
+class TransferConv2d(_ConvNd):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups: int = 1,
+            bias: bool = True,
+            padding_mode: str = 'zeros',
+            transfer_mode: bool = False
+    ):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _pair(0), groups, bias, padding_mode)
+
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
+
+    def _conv_forward(self, input, weight):
+        if self.transfer_mode:
+            weight = weight * self.transfer_scale + self.transfer_shift
+        if self.padding_mode != 'zeros':
+            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.bias, self.stride,
+                            _pair(0), self.dilation, self.groups)
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self._conv_forward(input, self.weight)
+
+
+class TransferLinear(nn.Module):
+    __constants__ = ['in_features', 'out_features']
+    in_features: int
+    out_features: int
+    weight: Tensor
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(out_features, in_features))
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+        self.transfer_mode = transfer_mode
+        if transfer_mode:
+            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
+            self.transfer_scale.FOR_TRANSFER_LEARNING = True
+            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
+            self.transfer_shift.FOR_TRANSFER_LEARNING = True
+
+    def reset_parameters(self) -> None:
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.transfer_mode:
+            weight = self.weight * self.transfer_scale + self.transfer_shift
+        else:
+            weight = self.weight
+        return F.linear(input, weight, self.bias)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features, self.bias is not None
+        )
+
+
+class TransferConvGnLelu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True,
+                 num_groups=8, weight_init_factor=1, transfer_mode=False):
+        super().__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias,
+                                   transfer_mode=transfer_mode)
+        if norm:
+            self.gn = nn.GroupNorm(num_groups, filters_out)
+        else:
+            self.gn = None
+        if activation:
+            self.lelu = nn.LeakyReLU(negative_slope=.2)
+        else:
+            self.lelu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, TransferConv2d):
+                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
+                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.gn:
+            x = self.gn(x)
+        if self.lelu:
+            return self.lelu(x)
+        else:
+            return x
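Everything above hinges on a small contract with the training loop: parameters created by transfer_mode carry FOR_TRANSFER_LEARNING, and the generator/discriminator constructors tag every other parameter DO_NOT_TRAIN. A sketch of how an optimizer could honor those tags (an assumption about the trainer's behavior, not its actual code):

def transfer_learning_params(model):
    # Keep only params the transfer tagging left trainable, i.e. the
    # transfer_scale/transfer_shift tensors when transfer_mode=True.
    return [p for p in model.parameters() if not getattr(p, 'DO_NOT_TRAIN', False)]

# e.g.: torch.optim.Adam(transfer_learning_params(gen), lr=1e-4)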
