forked from mrq/DL-Art-School
Transfer learning for styleSR
This is a concept from "Lifelong Learning GAN", although I'm skeptical of it's novelty - basically you scale and shift the weights for the generator and discriminator of a pretrained GAN to "shift" into new modalities, e.g. faces->birds or whatever. There are some interesting applications of this that I would like to try out.
This commit is contained in:
parent
2c65b6b28e
commit
ade2732c82
6
.idea/other.xml
Normal file
6
.idea/other.xml
Normal file
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PySciProjectComponent">
|
||||
<option name="PY_MATPLOTLIB_IN_TOOLWINDOW" value="false" />
|
||||
</component>
|
||||
</project>
|
|
@ -1,4 +1,6 @@
|
|||
# Heavily based on the lucidrains stylegan2 discriminator implementation.
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
from math import log2
|
||||
from random import random
|
||||
|
@ -6,31 +8,33 @@ from random import random
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch.autograd import grad as torch_grad
|
||||
import trainer.losses as L
|
||||
from vector_quantize_pytorch import VectorQuantize
|
||||
|
||||
from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
|
||||
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class DiscriminatorBlock(nn.Module):
|
||||
def __init__(self, input_channels, filters, downsample=True):
|
||||
def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
||||
self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(input_channels, filters, 3, padding=1),
|
||||
TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(filters, filters, 3, padding=1),
|
||||
TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
|
||||
leaky_relu()
|
||||
)
|
||||
|
||||
self.downsample = nn.Sequential(
|
||||
Blur(),
|
||||
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
|
||||
TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
|
||||
) if downsample else None
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -42,9 +46,10 @@ class DiscriminatorBlock(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class StyleGan2Discriminator(nn.Module):
|
||||
class StyleSrDiscriminator(nn.Module):
|
||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
|
||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
|
||||
transfer_mode=False):
|
||||
super().__init__()
|
||||
num_layers = int(log2(image_size) - 1)
|
||||
|
||||
|
@ -63,7 +68,7 @@ class StyleGan2Discriminator(nn.Module):
|
|||
num_layer = ind + 1
|
||||
is_not_last = ind != (len(chan_in_out) - 1)
|
||||
|
||||
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
|
||||
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
|
||||
blocks.append(block)
|
||||
|
||||
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
|
||||
|
@ -84,17 +89,23 @@ class StyleGan2Discriminator(nn.Module):
|
|||
chan_last = filters[-1]
|
||||
latent_dim = 2 * 2 * chan_last
|
||||
|
||||
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
|
||||
self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
|
||||
self.flatten = nn.Flatten()
|
||||
if mlp:
|
||||
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
|
||||
self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
|
||||
leaky_relu(),
|
||||
nn.Linear(100, 1))
|
||||
TransferLinear(100, 1, transfer_mode=transfer_mode))
|
||||
else:
|
||||
self.to_logit = nn.Linear(latent_dim, 1)
|
||||
self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
self.transfer_mode = transfer_mode
|
||||
if transfer_mode:
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
def forward(self, x):
|
||||
b, *_ = x.shape
|
||||
|
||||
|
@ -123,12 +134,12 @@ class StyleGan2Discriminator(nn.Module):
|
|||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||
if type(m) in {TransferConv2d, TransferLinear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
# Configures the network as partially pre-trained. This means:
|
||||
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
|
||||
# 2) The haed (linear layers) will also have their weights re-initialized
|
||||
# 2) The head (linear layers) will also have their weights re-initialized
|
||||
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
|
||||
# These settings will be applied after the weights have been loaded (network_loaded())
|
||||
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
|
||||
|
@ -150,7 +161,7 @@ class StyleGan2Discriminator(nn.Module):
|
|||
reset_blocks.append(self.blocks[i])
|
||||
for bl in reset_blocks:
|
||||
for m in bl.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||
if type(m) in {TransferConv2d, TransferLinear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
for p in m.parameters(recurse=True):
|
||||
p._NEW_BLOCK = True
|
||||
|
@ -237,13 +248,15 @@ class DiscAugmentor(nn.Module):
|
|||
self.prob = prob
|
||||
self.types = types
|
||||
|
||||
def forward(self, images, detach=False):
|
||||
def forward(self, images, real_images=False):
|
||||
if random() < self.prob:
|
||||
images = random_hflip(images, prob=0.5)
|
||||
images = DiffAugment(images, types=self.types)
|
||||
|
||||
if detach:
|
||||
images = images.detach()
|
||||
if real_images:
|
||||
self.hq_aug = images.detach().clone()
|
||||
else:
|
||||
self.gen_aug = images.detach().clone()
|
||||
|
||||
# Save away for use elsewhere (e.g. unet loss)
|
||||
self.aug_images = images
|
||||
|
@ -253,6 +266,11 @@ class DiscAugmentor(nn.Module):
|
|||
def network_loaded(self):
|
||||
self.D.network_loaded< |