forked from mrq/DL-Art-School
Transfer learning for styleSR
This is a concept from "Lifelong Learning GAN", although I'm skeptical of its novelty - basically you scale and shift the weights for the generator and discriminator of a pretrained GAN to "shift" into new modalities, e.g. faces->birds or whatever. There are some interesting applications of this that I would like to try out.
This commit is contained in:
parent
2c65b6b28e
commit
ade2732c82
6
.idea/other.xml
Normal file
6
.idea/other.xml
Normal file
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PySciProjectComponent">
|
||||
<option name="PY_MATPLOTLIB_IN_TOOLWINDOW" value="false" />
|
||||
</component>
|
||||
</project>
|
|
@ -1,4 +1,6 @@
|
|||
# Heavily based on the lucidrains stylegan2 discriminator implementation.
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
from math import log2
|
||||
from random import random
|
||||
|
@ -6,31 +8,33 @@ from random import random
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch.autograd import grad as torch_grad
|
||||
import trainer.losses as L
|
||||
from vector_quantize_pytorch import VectorQuantize
|
||||
|
||||
from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
|
||||
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class DiscriminatorBlock(nn.Module):
|
||||
def __init__(self, input_channels, filters, downsample=True):
|
||||
def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
||||
self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(input_channels, filters, 3, padding=1),
|
||||
TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(filters, filters, 3, padding=1),
|
||||
TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
|
||||
leaky_relu()
|
||||
)
|
||||
|
||||
self.downsample = nn.Sequential(
|
||||
Blur(),
|
||||
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
|
||||
TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
|
||||
) if downsample else None
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -42,9 +46,10 @@ class DiscriminatorBlock(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class StyleGan2Discriminator(nn.Module):
|
||||
class StyleSrDiscriminator(nn.Module):
|
||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
|
||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
|
||||
transfer_mode=False):
|
||||
super().__init__()
|
||||
num_layers = int(log2(image_size) - 1)
|
||||
|
||||
|
@ -63,7 +68,7 @@ class StyleGan2Discriminator(nn.Module):
|
|||
num_layer = ind + 1
|
||||
is_not_last = ind != (len(chan_in_out) - 1)
|
||||
|
||||
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
|
||||
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
|
||||
blocks.append(block)
|
||||
|
||||
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
|
||||
|
@ -84,17 +89,23 @@ class StyleGan2Discriminator(nn.Module):
|
|||
chan_last = filters[-1]
|
||||
latent_dim = 2 * 2 * chan_last
|
||||
|
||||
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
|
||||
self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
|
||||
self.flatten = nn.Flatten()
|
||||
if mlp:
|
||||
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
|
||||
self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
|
||||
leaky_relu(),
|
||||
nn.Linear(100, 1))
|
||||
TransferLinear(100, 1, transfer_mode=transfer_mode))
|
||||
else:
|
||||
self.to_logit = nn.Linear(latent_dim, 1)
|
||||
self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
self.transfer_mode = transfer_mode
|
||||
if transfer_mode:
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
def forward(self, x):
|
||||
b, *_ = x.shape
|
||||
|
||||
|
@ -123,12 +134,12 @@ class StyleGan2Discriminator(nn.Module):
|
|||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||
if type(m) in {TransferConv2d, TransferLinear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
# Configures the network as partially pre-trained. This means:
|
||||
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
|
||||
# 2) The haed (linear layers) will also have their weights re-initialized
|
||||
# 2) The head (linear layers) will also have their weights re-initialized
|
||||
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
|
||||
# These settings will be applied after the weights have been loaded (network_loaded())
|
||||
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
|
||||
|
@ -150,7 +161,7 @@ class StyleGan2Discriminator(nn.Module):
|
|||
reset_blocks.append(self.blocks[i])
|
||||
for bl in reset_blocks:
|
||||
for m in bl.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||
if type(m) in {TransferConv2d, TransferLinear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
for p in m.parameters(recurse=True):
|
||||
p._NEW_BLOCK = True
|
||||
|
@ -237,13 +248,15 @@ class DiscAugmentor(nn.Module):
|
|||
self.prob = prob
|
||||
self.types = types
|
||||
|
||||
def forward(self, images, detach=False):
|
||||
def forward(self, images, real_images=False):
|
||||
if random() < self.prob:
|
||||
images = random_hflip(images, prob=0.5)
|
||||
images = DiffAugment(images, types=self.types)
|
||||
|
||||
if detach:
|
||||
images = images.detach()
|
||||
if real_images:
|
||||
self.hq_aug = images.detach().clone()
|
||||
else:
|
||||
self.gen_aug = images.detach().clone()
|
||||
|
||||
# Save away for use elsewhere (e.g. unet loss)
|
||||
self.aug_images = images
|
||||
|
@ -253,6 +266,11 @@ class DiscAugmentor(nn.Module):
|
|||
def network_loaded(self):
|
||||
self.D.network_loaded()
|
||||
|
||||
# Allows visualizing what the augmentor is up to.
|
||||
def visual_dbg(self, step, path):
|
||||
torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
|
||||
torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))
|
||||
|
||||
|
||||
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
|
||||
if fp16:
|
||||
|
@ -294,12 +312,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
|
|||
real_input = real_input + torch.rand_like(real_input) * self.noise
|
||||
|
||||
D = self.env['discriminators'][self.discriminator]
|
||||
fake = D(fake_input)
|
||||
fake = D(fake_input, real_images=False)
|
||||
if self.for_gen:
|
||||
return fake.mean()
|
||||
else:
|
||||
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
||||
real = D(real_input)
|
||||
real = D(real_input, real_images=True)
|
||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
|
@ -315,10 +333,12 @@ class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
|
|||
@register_model
|
||||
def register_styledsr_discriminator(opt_net, opt):
|
||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
|
||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
|
||||
quantize=opt_get(opt_net, ['quantize'], False),
|
||||
mlp=opt_get(opt_net, ['mlp_head'], True))
|
||||
disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
|
||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
|
||||
quantize=opt_get(opt_net, ['quantize'], False),
|
||||
mlp=opt_get(opt_net, ['mlp_head'], True),
|
||||
transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
|
||||
)
|
||||
if 'use_partial_pretrained' in opt_net.keys():
|
||||
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
|
||||
return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||
|
|
|
@ -1,29 +1,37 @@
|
|||
from math import log2
|
||||
from random import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.RRDBNet_arch import RRDB
|
||||
from models.arch_util import ConvGnLelu, default_init_weights
|
||||
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
|
||||
from models.arch_util import kaiming_init
|
||||
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
|
||||
from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
def rrdb_init_weights(module, scale=1):
    """Kaiming-initialize every transfer conv/linear inside `module`, then
    scale the weights by `scale`.

    Used to give RRDB blocks a small initial residual contribution
    (callers pass scale=.1).
    """
    for m in module.modules():
        # Both transfer primitives receive identical treatment, so a single
        # tuple isinstance replaces the two duplicated branches.
        if isinstance(m, (TransferConv2d, TransferLinear)):
            kaiming_init(m, a=0, mode='fan_in', bias=0)
            m.weight.data *= scale
|
||||
|
||||
|
||||
class EncoderRRDB(nn.Module):
|
||||
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1):
|
||||
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
|
||||
super(EncoderRRDB, self).__init__()
|
||||
for i in range(5):
|
||||
out_channels = output_channels if i == 4 else growth_channels
|
||||
self.add_module(
|
||||
f'conv{i+1}',
|
||||
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
|
||||
1, 1))
|
||||
TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
for i in range(5):
|
||||
default_init_weights(getattr(self, f'conv{i+1}'), init_weight)
|
||||
rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
|
@ -35,18 +43,18 @@ class EncoderRRDB(nn.Module):
|
|||
|
||||
|
||||
class StyledSrEncoder(nn.Module):
|
||||
def __init__(self, fea_out=256, initial_stride=1):
|
||||
def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
|
||||
super().__init__()
|
||||
# Current assumes fea_out=256.
|
||||
self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
|
||||
self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
|
||||
self.rrdbs = nn.ModuleList([
|
||||
EncoderRRDB(32),
|
||||
EncoderRRDB(64),
|
||||
EncoderRRDB(96),
|
||||
EncoderRRDB(128),
|
||||
EncoderRRDB(160),
|
||||
EncoderRRDB(192),
|
||||
EncoderRRDB(224)])
|
||||
EncoderRRDB(32, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(64, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(96, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(128, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(160, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(192, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(224, transfer_mode=transfer_mode)])
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.initial_conv(x)
|
||||
|
@ -56,13 +64,14 @@ class StyledSrEncoder(nn.Module):
|
|||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
|
||||
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
|
||||
super().__init__()
|
||||
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
|
||||
self.image_size = image_size
|
||||
self.scale = 2 ** upsample_levels
|
||||
self.latent_dim = latent_dim
|
||||
self.num_layers = total_levels
|
||||
self.transfer_mode = transfer_mode
|
||||
filters = [
|
||||
512, # 4x4
|
||||
512, # 8x8
|
||||
|
@ -75,7 +84,8 @@ class Generator(nn.Module):
|
|||
8, # 1024x1024
|
||||
]
|
||||
|
||||
self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
|
||||
# I'm making a guess here that the encoder does not need transfer learning, hence fixed transfer_mode=False. This should be vetted.
|
||||
self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)
|
||||
|
||||
in_out_pairs = list(zip(filters[:-1], filters[1:]))
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
@ -88,13 +98,18 @@ class Generator(nn.Module):
|
|||
in_chan,
|
||||
out_chan,
|
||||
upsample=not_first,
|
||||
upsample_rgb=not_last
|
||||
upsample_rgb=not_last,
|
||||
transfer_learning_mode=transfer_mode
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
def forward(self, lr, styles):
|
||||
b, c, h, w = lr.shape
|
||||
x = self.encoder(lr)
|
||||
if self.transfer_mode:
|
||||
with torch.no_grad():
|
||||
x = self.encoder(lr)
|
||||
else:
|
||||
x = self.encoder(lr)
|
||||
|
||||
styles = styles.transpose(0, 1)
|
||||
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
|
||||
|
@ -102,6 +117,7 @@ class Generator(nn.Module):
|
|||
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
|
||||
else:
|
||||
rgb = lr
|
||||
|
||||
for style, block in zip(styles, self.blocks):
|
||||
x, rgb = checkpoint(block, x, rgb, style, input_noise)
|
||||
|
||||
|
@ -109,16 +125,23 @@ class Generator(nn.Module):
|
|||
|
||||
|
||||
class StyledSrGenerator(nn.Module):
|
||||
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
|
||||
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
|
||||
super().__init__()
|
||||
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
|
||||
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
|
||||
# Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
|
||||
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
|
||||
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
|
||||
self.mixed_prob = .9
|
||||
self._init_weights()
|
||||
self.transfer_mode = transfer_mode
|
||||
if transfer_mode:
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
|
||||
if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
for block in self.gen.blocks:
|
||||
|
@ -132,7 +155,11 @@ class StyledSrGenerator(nn.Module):
|
|||
|
||||
# Synthesize style latents from noise.
|
||||
style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
|
||||
w = self.vectorizer(style)
|
||||
if self.transfer_mode:
|
||||
with torch.no_grad():
|
||||
w = self.vectorizer(style)
|
||||
else:
|
||||
w = self.vectorizer(style)
|
||||
|
||||
# Randomly distribute styles across layers
|
||||
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
|
||||
|
@ -162,4 +189,6 @@ if __name__ == '__main__':
|
|||
|
||||
@register_model
|
||||
def register_styled_sr(opt_net, opt):
|
||||
return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
|
||||
return StyledSrGenerator(128,
|
||||
initial_stride=opt_get(opt_net, ['initial_stride'], 1),
|
||||
transfer_mode=opt_get(opt_net, ['transfer_mode'], False))
|
||||
|
|
|
@ -6,8 +6,12 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from kornia.filters import filter2D
|
||||
from linear_attention_transformer import ImageLinearAttention
|
||||
from torch import nn
|
||||
from torch import nn, Tensor
|
||||
from torch.autograd import grad as torch_grad
|
||||
from torch.nn import Parameter, init
|
||||
from torch.nn.modules.conv import _ConvNd
|
||||
|
||||
from models.styled_sr.transfer_primitives import TransferLinear
|
||||
|
||||
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
|
||||
|
||||
|
@ -196,10 +200,9 @@ def slerp(val, low, high):
|
|||
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
|
||||
return res
|
||||
|
||||
# stylegan2 classes
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
|
||||
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
|
||||
if bias:
|
||||
|
@ -207,17 +210,28 @@ class EqualLinear(nn.Module):
|
|||
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
self.transfer_mode = transfer_mode
|
||||
if transfer_mode:
|
||||
self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
|
||||
self.transfer_scale.FOR_TRANSFER_LEARNING = True
|
||||
self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
|
||||
self.transfer_shift.FOR_TRANSFER_LEARNING = True
|
||||
|
||||
def forward(self, input):
|
||||
return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
|
||||
if self.transfer_mode:
|
||||
weight = self.weight * self.transfer_scale + self.transfer_shift
|
||||
else:
|
||||
weight = self.weight
|
||||
return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)
|
||||
|
||||
|
||||
class StyleVectorizer(nn.Module):
|
||||
def __init__(self, emb, depth, lr_mul=0.1):
|
||||
def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
for i in range(depth):
|
||||
layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])
|
||||
layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
|
@ -227,13 +241,13 @@ class StyleVectorizer(nn.Module):
|
|||
|
||||
|
||||
class RGBBlock(nn.Module):
|
||||
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
|
||||
def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
|
||||
super().__init__()
|
||||
self.input_channel = input_channel
|
||||
self.to_style = nn.Linear(latent_dim, input_channel)
|
||||
|
||||
out_filters = 3 if not rgba else 4
|
||||
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
|
||||
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)
|
||||
|
||||
self.upsample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
||||
|
@ -307,25 +321,11 @@ class EqualLR:
|
|||
|
||||
def equal_lr(module, name='weight'):
|
||||
EqualLR.apply(module, name)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
conv = nn.Conv2d(*args, **kwargs)
|
||||
conv.weight.data.normal_()
|
||||
conv.bias.data.zero_()
|
||||
self.conv = equal_lr(conv)
|
||||
|
||||
def forward(self, input):
|
||||
return self.conv(input)
|
||||
|
||||
|
||||
class Conv2DMod(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
|
||||
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
|
||||
super().__init__()
|
||||
self.filters = out_chan
|
||||
self.demod = demod
|
||||
|
@ -334,6 +334,12 @@ class Conv2DMod(nn.Module):
|
|||
self.dilation = dilation
|
||||
self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
|
||||
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
self.transfer_mode = transfer_mode
|
||||
if transfer_mode:
|
||||
self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
|
||||
self.transfer_scale.FOR_TRANSFER_LEARNING = True
|
||||
self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
|
||||
self.transfer_shift.FOR_TRANSFER_LEARNING = True
|
||||
|
||||
def _get_same_padding(self, size, kernel, dilation, stride):
|
||||
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
|
||||
|
@ -341,8 +347,13 @@ class Conv2DMod(nn.Module):
|
|||
def forward(self, x, y):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
if self.transfer_mode:
|
||||
weight = self.weight * self.transfer_scale + self.transfer_shift
|
||||
else:
|
||||
weight = self.weight
|
||||
|
||||
w1 = y[:, None, :, None, None]
|
||||
w2 = self.weight[None, :, :, :, :]
|
||||
w2 = weight[None, :, :, :, :]
|
||||
weights = w2 * (w1 + 1)
|
||||
|
||||
if self.demod:
|
||||
|
@ -362,34 +373,28 @@ class Conv2DMod(nn.Module):
|
|||
|
||||
|
||||
class GeneratorBlock(nn.Module):
|
||||
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
|
||||
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
|
||||
transfer_learning_mode=False):
|
||||
super().__init__()
|
||||
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
|
||||
|
||||
self.structure_input = structure_input
|
||||
if self.structure_input:
|
||||
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
|
||||
input_channels = input_channels * 2
|
||||
self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
|
||||
self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
|
||||
self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)
|
||||
|
||||
self.to_style1 = nn.Linear(latent_dim, input_channels)
|
||||
self.to_noise1 = nn.Linear(1, filters)
|
||||
self.conv1 = Conv2DMod(input_channels, filters, 3)
|
||||
|
||||
self.to_style2 = nn.Linear(latent_dim, filters)
|
||||
self.to_noise2 = nn.Linear(1, filters)
|
||||
self.conv2 = Conv2DMod(filters, filters, 3)
|
||||
self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
|
||||
self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
|
||||
self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)
|
||||
|
||||
self.activation = leaky_relu()
|
||||
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
|
||||
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)
|
||||
|
||||
def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
|
||||
self.transfer_learning_mode = transfer_learning_mode
|
||||
|
||||
def forward(self, x, prev_rgb, istyle, inoise):
|
||||
if exists(self.upsample):
|
||||
x = self.upsample(x)
|
||||
|
||||
if self.structure_input:
|
||||
s = self.structure_conv(structure_input)
|
||||
x = torch.cat([x, s], dim=1)
|
||||
|
||||
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
|
||||
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
|
||||
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
|
||||
|
|
136
codes/models/styled_sr/transfer_primitives.py
Normal file
136
codes/models/styled_sr/transfer_primitives.py
Normal file
|
@ -0,0 +1,136 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn import Parameter, init
|
||||
from torch.nn.modules.conv import _ConvNd
|
||||
from torch.nn.modules.utils import _ntuple
|
||||
|
||||
_pair = _ntuple(2)
|
||||
|
||||
class TransferConv2d(_ConvNd):
    """2D convolution with optional "transfer learning" weight modulation.

    When `transfer_mode` is True, the (presumably frozen, pretrained) weight
    is affinely modulated per (out_channel, in_channel) pair before the
    convolution:  effective_weight = weight * transfer_scale + transfer_shift.
    The scale/shift parameters are tagged with `FOR_TRANSFER_LEARNING` so the
    trainer can select only them for optimization while everything else stays
    frozen.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            transfer_mode: bool = False
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        self.transfer_mode = transfer_mode
        if transfer_mode:
            # NOTE(review): scale/shift are shaped (out, in, 1, 1), which
            # assumes groups == 1 — a grouped conv's weight has
            # in_channels // groups in dim 1 and would not broadcast.
            # Confirm grouped usage is never combined with transfer_mode.
            self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _conv_forward(self, input, weight):
        # Apply the learned affine modulation to the pretrained weight.
        # (Removed the original no-op `else: weight = weight` branch.)
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift

        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight)
|
||||
|
||||
|
||||
class TransferLinear(nn.Module):
    """Linear layer with optional "transfer learning" weight modulation.

    When `transfer_mode` is True, the (presumably frozen, pretrained) weight
    is affinely modulated elementwise before the matmul:
    effective_weight = weight * transfer_scale + transfer_shift.
    The scale/shift parameters are tagged with `FOR_TRANSFER_LEARNING` so the
    trainer can train only them.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty instead of the legacy torch.Tensor(sizes) constructor;
        # contents are overwritten by reset_parameters() below either way.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def reset_parameters(self) -> None:
        """Standard nn.Linear init: Kaiming-uniform weight, fan-in bounded bias."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
|
||||
|
||||
|
||||
class TransferConvGnLelu(nn.Module):
    """Conv -> (optional) GroupNorm -> (optional) LeakyReLU, built on
    TransferConv2d so the scale/shift transfer parameters are available when
    transfer_mode=True.
    """
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
        super().__init__()
        # Only odd kernel sizes with "same"-style padding are supported.
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map
        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias, transfer_mode=transfer_mode)
        self.gn = nn.GroupNorm(num_groups, filters_out) if norm else None
        self.lelu = nn.LeakyReLU(negative_slope=.2) if activation else None

        # Init params.
        for m in self.modules():
            if isinstance(m, TransferConv2d):
                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.gn:
            out = self.gn(out)
        return self.lelu(out) if self.lelu else out
|
0
codes/scripts/tsne_torch.py
Normal file
0
codes/scripts/tsne_torch.py
Normal file
Loading…
Reference in New Issue
Block a user