Transfer learning for styleSR

This is a concept from "Lifelong Learning GAN", although I'm skeptical of it's novelty -
basically you scale and shift the weights for the generator and discriminator of a pretrained
GAN to "shift" into new modalities, e.g. faces->birds or whatever. There are some interesting
applications of this that I would like to try out.
This commit is contained in:
James Betker 2021-01-04 20:10:48 -07:00
parent 2c65b6b28e
commit ade2732c82
6 changed files with 289 additions and 93 deletions

6
.idea/other.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PySciProjectComponent">
<option name="PY_MATPLOTLIB_IN_TOOLWINDOW" value="false" />
</component>
</project>

View File

@ -1,4 +1,6 @@
# Heavily based on the lucidrains stylegan2 discriminator implementation.
import math
import os
from functools import partial
from math import log2
from random import random
@ -6,31 +8,33 @@ from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import grad as torch_grad
import trainer.losses as L
from vector_quantize_pytorch import VectorQuantize
from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
super().__init__()
self.filters = filters
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
) if downsample else None
def forward(self, x):
@ -42,9 +46,10 @@ class DiscriminatorBlock(nn.Module):
return x
class StyleGan2Discriminator(nn.Module):
class StyleSrDiscriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
transfer_mode=False):
super().__init__()
num_layers = int(log2(image_size) - 1)
@ -63,7 +68,7 @@ class StyleGan2Discriminator(nn.Module):
num_layer = ind + 1
is_not_last = ind != (len(chan_in_out) - 1)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
blocks.append(block)
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
@ -84,17 +89,23 @@ class StyleGan2Discriminator(nn.Module):
chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
self.flatten = nn.Flatten()
if mlp:
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
leaky_relu(),
nn.Linear(100, 1))
TransferLinear(100, 1, transfer_mode=transfer_mode))
else:
self.to_logit = nn.Linear(latent_dim, 1)
self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
self._init_weights()
self.transfer_mode = transfer_mode
if transfer_mode:
for p in self.parameters():
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
p.DO_NOT_TRAIN = True
def forward(self, x):
b, *_ = x.shape
@ -123,12 +134,12 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
if type(m) in {TransferConv2d, TransferLinear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
# Configures the network as partially pre-trained. This means:
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
# 2) The haed (linear layers) will also have their weights re-initialized
# 2) The head (linear layers) will also have their weights re-initialized
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
# These settings will be applied after the weights have been loaded (network_loaded())
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
@ -150,7 +161,7 @@ class StyleGan2Discriminator(nn.Module):
reset_blocks.append(self.blocks[i])
for bl in reset_blocks:
for m in bl.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
if type(m) in {TransferConv2d, TransferLinear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for p in m.parameters(recurse=True):
p._NEW_BLOCK = True
@ -237,13 +248,15 @@ class DiscAugmentor(nn.Module):
self.prob = prob
self.types = types
def forward(self, images, detach=False):
def forward(self, images, real_images=False):
if random() < self.prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=self.types)
if detach:
images = images.detach()
if real_images:
self.hq_aug = images.detach().clone()
else:
self.gen_aug = images.detach().clone()
# Save away for use elsewhere (e.g. unet loss)
self.aug_images = images
@ -253,6 +266,11 @@ class DiscAugmentor(nn.Module):
def network_loaded(self):
self.D.network_loaded<