GLEAN!
This commit is contained in:
parent
c717765bcb
commit
92f9a129f7
116
codes/models/glean/glean.py
Normal file
116
codes/models/glean/glean.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
from models.RRDBNet_arch import RRDB
|
||||
from models.arch_util import ConvGnLelu
|
||||
|
||||
|
||||
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
|
||||
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
|
||||
from models.stylegan.stylegan2_rosinality import EqualLinear
|
||||
from utils.util import checkpoint, sequential_checkpoint
|
||||
|
||||
|
||||
class GleanEncoderBlock(nn.Module):
|
||||
def __init__(self, nf):
|
||||
super().__init__()
|
||||
self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
|
||||
self.process = nn.Sequential(
|
||||
ConvGnLelu(nf, nf*2, kernel_size=3, stride=2, activation=True, norm=False, bias=False),
|
||||
ConvGnLelu(nf*2, nf*2, kernel_size=3, activation=True, norm=False, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
structural_latent = self.structural_latent_conv(x)
|
||||
fea = self.process(x)
|
||||
return fea, structural_latent
|
||||
|
||||
|
||||
# Produces RRDB features, a list of convolutional features (`f` shape=[l][b,c,h,w] l=levels aka f_sub)
|
||||
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
|
||||
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
|
||||
class GleanEncoder(nn.Module):
|
||||
def __init__(self, nf, nb, reductions=4, latent_bank_blocks=13, latent_bank_latent_dim=512, input_dim=32):
|
||||
super().__init__()
|
||||
self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True)
|
||||
self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)])
|
||||
self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)])
|
||||
|
||||
reducer_output_dim = (input_dim // (2 ** reductions)) ** 2
|
||||
reducer_output_nf = nf * 2 ** reductions
|
||||
self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, kernel_size=1, activation=True, norm=False, bias=True)
|
||||
# This is a questionable part of this architecture. Apply multiple Denses to separate outputs (as I've done here)?
|
||||
# Apply a single dense, then split the outputs? Who knows..
|
||||
self.latent_linears = nn.ModuleList([EqualLinear(reducer_output_dim * reducer_output_nf, latent_bank_latent_dim,
|
||||
activation="fused_lrelu")
|
||||
for _ in range(latent_bank_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.initial_conv(x)
|
||||
fea = sequential_checkpoint(self.rrdb_blocks, len(self.rrdb_blocks), fea)
|
||||
rrdb_fea = fea
|
||||
convolutional_features = []
|
||||
for reducer in self.reducers:
|
||||
fea, f = checkpoint(reducer, fea)
|
||||
convolutional_features.append(f)
|
||||
|
||||
latents = self.latent_conv(fea)
|
||||
latents = [dense(latents.flatten(1, -1)) for dense in self.latent_linears]
|
||||
latents = torch.stack(latents, dim=1)
|
||||
|
||||
return rrdb_fea, convolutional_features, latents
|
||||
|
||||
|
||||
# Produces an image by fusing the output features from the latent bank.
|
||||
class GleanDecoder(nn.Module):
|
||||
# To determine latent_bank_filters, use the `self.channels` map for the desired input dimensions from stylegan2_rosinality.py
|
||||
def __init__(self, nf, latent_bank_filters=[512, 256, 128]):
|
||||
super().__init__()
|
||||
self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=False, norm=False, bias=True)
|
||||
|
||||
# The paper calls for pixel shuffling each output of the decoder. We need to make sure that is possible. Doing it by using the latent bank filters as the output filters for each decoder stage
|
||||
assert latent_bank_filters[-1] % 4 == 0
|
||||
decoder_block_shuffled_dims = [nf // 4]
|
||||
decoder_block_shuffled_dims.extend([l // 4 for l in latent_bank_filters])
|
||||
self.decoder_blocks = nn.ModuleList([ConvGnLelu(decoder_block_shuffled_dims[i] + latent_bank_filters[i],
|
||||
latent_bank_filters[i],
|
||||
kernel_size=3, bias=True, norm=False, activation=False)
|
||||
for i in range(len(latent_bank_filters))])
|
||||
self.shuffler = nn.PixelShuffle(2) # TODO: I'm a bit skeptical about this. It doesn't align with RRDB or StyleGAN. It also always produces artifacts in my experience. Try using interpolation instead.
|
||||
|
||||
final_dim = latent_bank_filters[-1]
|
||||
self.final_decode = nn.Sequential(ConvGnLelu(final_dim, final_dim, kernel_size=3, activation=True, bias=True, norm=False),
|
||||
ConvGnLelu(final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False))
|
||||
|
||||
def forward(self, rrdb_fea, latent_bank_fea):
|
||||
fea = self.initial_conv(rrdb_fea)
|
||||
for i, block in enumerate(self.decoder_blocks):
|
||||
fea = self.shuffler(fea)
|
||||
fea = torch.cat([fea, latent_bank_fea[i]], dim=1)
|
||||
fea = checkpoint(block, fea)
|
||||
return self.final_decode(fea)
|
||||
|
||||
|
||||
class GleanGenerator(nn.Module):
|
||||
def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
|
||||
encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32):
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
latent_blocks = int(math.log(gen_output_dim, 2)) - 1 # From 4x4->gen_output_dim x gen_output_dim
|
||||
latent_blocks = latent_blocks * 2 + 1 # Two styled convolutions per block, + an initial styled conv.
|
||||
self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks * 2 + 1,
|
||||
latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim)
|
||||
decoder_blocks = int(math.log(gen_output_dim/input_dim, 2))
|
||||
latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
|
||||
self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
|
||||
latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions,
|
||||
decoder_levels=decoder_blocks)
|
||||
self.decoder = GleanDecoder(nf, latent_bank_filters_out)
|
||||
|
||||
def forward(self, x):
|
||||
assert self.input_dim == x.shape[-1] and self.input_dim == x.shape[-2]
|
||||
rrdb_fea, conv_fea, latents = self.encoder(x)
|
||||
latent_bank_fea = self.latent_bank(conv_fea, latents)
|
||||
return self.decoder(rrdb_fea, latent_bank_fea)
|
66
codes/models/glean/stylegan2_latent_bank.py
Normal file
66
codes/models/glean/stylegan2_latent_bank.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.arch_util import ConvGnLelu
|
||||
from models.stylegan.stylegan2_rosinality import Generator
|
||||
|
||||
|
||||
class Stylegan2LatentBank(nn.Module):
|
||||
def __init__(self, pretrained_model_file, encoder_nf=64, max_dim=1024, latent_dim=512, encoder_levels=4, decoder_levels=3):
|
||||
super().__init__()
|
||||
|
||||
# Initialize the bank.
|
||||
self.bank = Generator(size=max_dim, style_dim=latent_dim, n_mlp=8, channel_multiplier=2) # Assumed using 'f' generators with mult=2.
|
||||
state_dict = torch.load(pretrained_model_file)
|
||||
self.bank.load_state_dict(state_dict, strict=True)
|
||||
|
||||
# Shut off training of the latent bank.
|
||||
for p in self.bank.parameters():
|
||||
p.requires_grad = False
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
# TODO: Compute these based on the underlying stylegans channels member variable.
|
||||
stylegan_encoder_dims = [512, 512, 512, 512]
|
||||
|
||||
# Initialize the fusion blocks. TODO: Try using the StyledConvs instead of regular ones.
|
||||
encoder_output_dims = reversed([64 * 2 ** i for i in range(encoder_levels)])
|
||||
input_dims_by_layer = [eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)]
|
||||
self.fusion_blocks = nn.ModuleList([ConvGnLelu(in_filters, out_filters, kernel_size=3, activation=True, norm=False, bias=True)
|
||||
for in_filters, out_filters in zip(input_dims_by_layer, stylegan_encoder_dims)])
|
||||
|
||||
self.decoder_levels = decoder_levels
|
||||
self.decoder_start = encoder_levels - 1
|
||||
self.total_levels = encoder_levels + decoder_levels - 1
|
||||
|
||||
# This forward mirrors the forward() pass from the rosinality stylegan2 implementation, with the additions called
|
||||
# for from the GLEAN paper. GLEAN mods are annotated with comments.
|
||||
# Removed stuff:
|
||||
# - Support for split latents (we're spoonfeeding them)
|
||||
# - Support for fixed noise inputs
|
||||
# - RGB computations -> we only care about the latents
|
||||
# - Style MLP -> GLEAN computes the Style inputs directly.
|
||||
# - Later layers -> GLEAN terminates at 256 resolution.
|
||||
def forward(self, convolutional_features, latent_vectors):
|
||||
|
||||
out = self.bank.input(latent_vectors[:, 0]) # The input here is only used to fetch the batch size.
|
||||
out = self.bank.conv1(out, latent_vectors[:, 0], noise=None)
|
||||
|
||||
i, k = 1, 0
|
||||
decoder_outputs = []
|
||||
for conv1, conv2 in zip(self.bank.convs[::2], self.bank.convs[1::2]):
|
||||
if k < len(self.fusion_blocks):
|
||||
out = torch.cat([convolutional_features[-k-1], out], dim=1)
|
||||
out = self.fusion_blocks[k](out)
|
||||
|
||||
out = conv1(out, latent_vectors[:, i], noise=None)
|
||||
out = conv2(out, latent_vectors[:, i + 1], noise=None)
|
||||
|
||||
if k >= self.decoder_start:
|
||||
decoder_outputs.append(out)
|
||||
if k >= self.total_levels:
|
||||
break
|
||||
|
||||
i += 2
|
||||
k += 1
|
||||
|
||||
return decoder_outputs
|
|
@ -1,11 +1,11 @@
|
|||
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||
|
||||
|
||||
def create_stylegan2_loss(opt_loss, env):
|
||||
type = opt_loss['type']
|
||||
if type == 'stylegan2_divergence':
|
||||
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||
return stylegan2.StyleGan2DivergenceLoss(opt_loss, env)
|
||||
elif type == 'stylegan2_pathlen':
|
||||
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||
return stylegan2.StyleGan2PathLengthLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
|
@ -8,7 +8,7 @@ from random import random
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import models.steps.losses as L
|
||||
import trainer.losses as L
|
||||
import numpy as np
|
||||
|
||||
from kornia.filters import filter2D
|
||||
|
|
|
@ -269,7 +269,7 @@ if __name__ == "__main__":
|
|||
ckpt["d"] = d_state
|
||||
|
||||
name = os.path.splitext(os.path.basename(args.path))[0]
|
||||
torch.save(ckpt, name + ".pt")
|
||||
torch.save(state_dict, name + ".pth")
|
||||
|
||||
batch_size = {256: 16, 512: 9, 1024: 4}
|
||||
n_sample = batch_size.get(size, 25)
|
||||
|
|
|
@ -125,6 +125,9 @@ def define_G(opt, opt_net, scale=None):
|
|||
from models.spinenet_arch import SpinenetWithLogits
|
||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
elif which_model == 'glean':
|
||||
from models.glean.glean import GleanGenerator
|
||||
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
Loading…
Reference in New Issue
Block a user