DL-Art-School/dlas/models/image_generation/glean/glean.py

139 lines
6.9 KiB
Python
Raw Normal View History

2020-12-18 23:04:19 +00:00
import math
import torch
import torch.nn as nn
2020-12-18 23:04:19 +00:00
from dlas.models.arch_util import ConvGnLelu
2020-12-18 23:04:19 +00:00
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from dlas.models.image_generation.glean.stylegan2_latent_bank import \
Stylegan2LatentBank
from dlas.models.image_generation.RRDBNet_arch import RRDB
from dlas.models.image_generation.stylegan.stylegan2_rosinality import \
EqualLinear
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, sequential_checkpoint
2020-12-18 23:04:19 +00:00
class GleanEncoderBlock(nn.Module):
    """Single reduction stage of the GLEAN encoder.

    From one input it yields two tensors:
      * a processed feature produced by a stride-2 3x3 conv followed by a second
        3x3 conv, with `min(nf * 2, max_nf)` filters, and
      * a "structural latent" produced by a 1x1 conv that keeps `nf` filters.

    Args:
        nf: Number of filters in the incoming feature map.
        max_nf: Upper bound on the filter count of the processed feature.
    """

    def __init__(self, nf, max_nf):
        super().__init__()
        # Filter count doubles per stage but is capped at max_nf.
        widened_nf = min(nf * 2, max_nf)

        # Side branch: bias-only 1x1 projection, no norm and no activation.
        self.structural_latent_conv = ConvGnLelu(
            nf, nf, kernel_size=1, activation=False, norm=False, bias=True)

        # Main branch: strided conv, then a same-width conv.
        strided_conv = ConvGnLelu(nf, widened_nf, kernel_size=3, stride=2,
                                  activation=True, norm=False, bias=False)
        same_width_conv = ConvGnLelu(widened_nf, widened_nf, kernel_size=3,
                                     activation=True, norm=False, bias=False)
        self.process = nn.Sequential(strided_conv, same_width_conv)

    def forward(self, x):
        """Return `(processed_feature, structural_latent)` computed from `x`."""
        side_branch = self.structural_latent_conv(x)
        main_branch = self.process(x)
        return main_branch, side_branch
# Produces RRDB features, a list of convolutional features (`f` shape=[l][b,c,h,w] l=levels aka f_sub)
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
class GleanEncoder(nn.Module):
    """Encoder producing RRDB features, per-level conv features and latent vectors.

    Args:
        nf: Base filter count produced by the initial convolution.
        nb: Number of RRDB blocks applied after the initial convolution.
        max_nf: Cap on the filter count of every reduction stage.
        reductions: Number of `GleanEncoderBlock` stages.
        latent_bank_blocks: Number of latent vectors emitted per sample.
        latent_bank_latent_dim: Size of each emitted latent vector.
        input_dim: Spatial size used to size the final linear layer; it is divided
            by `2 ** (reductions + 1)` (the reducers plus the stride-2 latent conv).
        initial_stride: Stride of the initial 7x7 convolution.
    """

    def __init__(self, nf, nb, max_nf=512, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512,
                 input_dim=32, initial_stride=1):
        super().__init__()
        self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True,
                                       stride=initial_stride)
        self.rrdb_blocks = nn.Sequential(*(RRDB(nf) for _ in range(nb)))

        # Stage `level` receives nf * 2**level filters, capped at max_nf.
        stage_inputs = [min(nf * 2 ** level, max_nf) for level in range(reductions)]
        self.reducers = nn.ModuleList([GleanEncoderBlock(stage_nf, max_nf) for stage_nf in stage_inputs])

        # Filter count and spatial side length after every reducer and the latent conv.
        tail_nf = min(nf * 2 ** reductions, max_nf)
        tail_side = input_dim // (2 ** (reductions + 1))
        self.latent_conv = ConvGnLelu(tail_nf, tail_nf, stride=2, kernel_size=3,
                                      activation=True, norm=False, bias=True)
        self.latent_linear = EqualLinear(tail_side ** 2 * tail_nf,
                                         latent_bank_latent_dim * latent_bank_blocks,
                                         activation="fused_lrelu")
        self.latent_bank_blocks = latent_bank_blocks

    def forward(self, x):
        """Return `(rrdb_fea, convolutional_features, latents)` for input `x`.

        `convolutional_features` holds the second output of each reducer in order;
        `latents` is reshaped to `(batch, latent_bank_blocks, -1)`.
        """
        stem = self.initial_conv(x)
        rrdb_fea = sequential_checkpoint(self.rrdb_blocks, len(self.rrdb_blocks), stem)

        fea = rrdb_fea
        convolutional_features = []
        for stage in self.reducers:
            fea, structural = checkpoint(stage, fea)
            convolutional_features.append(structural)

        flattened = self.latent_conv(fea).flatten(1, -1)
        latents = self.latent_linear(flattened).view(fea.shape[0], self.latent_bank_blocks, -1)
        return rrdb_fea, convolutional_features, latents
# Produces an image by fusing the output features from the latent bank.
class GleanDecoder(nn.Module):
    """Decodes an image by fusing RRDB features with latent-bank features.

    To determine latent_bank_filters, use the `self.channels` map for the desired input dimensions from
    stylegan2_rosinality.py.

    Args:
        nf: Filter count of the incoming RRDB feature map.
        latent_bank_filters: Filter counts of the latent-bank features, one per decoder
            level, in the order they are consumed. Defaults to `[512, 256, 128]`.
    """

    def __init__(self, nf, latent_bank_filters=None):
        super().__init__()
        # A None sentinel replaces the former mutable list default; copying also lets
        # callers pass any sequence (e.g. a tuple) without it being aliased or mutated.
        if latent_bank_filters is None:
            latent_bank_filters = [512, 256, 128]
        else:
            latent_bank_filters = list(latent_bank_filters)

        self.initial_conv = ConvGnLelu(
            nf, nf, kernel_size=3, activation=True, norm=False, bias=True, weight_init_factor=.1)

        # Level i receives the previous level's output (nf for the first level)
        # concatenated with the i-th latent-bank feature along the channel axis.
        incoming_filters = [nf] + latent_bank_filters
        self.decoder_blocks = nn.ModuleList([
            ConvGnLelu(carried + bank, bank,
                       kernel_size=3, bias=True, norm=False, activation=True,
                       weight_init_factor=.1)
            for carried, bank in zip(incoming_filters, latent_bank_filters)])

        final_dim = latent_bank_filters[-1]
        self.final_decode = ConvGnLelu(
            final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False, weight_init_factor=.1)

    def forward(self, rrdb_fea, latent_bank_fea):
        """Return the decoded 3-channel output.

        Args:
            rrdb_fea: Feature map fed to the initial convolution.
            latent_bank_fea: Indexable collection with one feature per decoder block;
                entry `i` is concatenated on dim 1 after the i-th 2x upsample.
        """
        fea = self.initial_conv(rrdb_fea)
        for i, block in enumerate(self.decoder_blocks):
            # The paper calls for PixelShuffle here, but I don't have good experience with that. It also doesn't
            # align with the way the underlying StyleGAN works.
            fea = nn.functional.interpolate(
                fea, scale_factor=2, mode="nearest")
            fea = torch.cat([fea, latent_bank_fea[i]], dim=1)
            fea = checkpoint(block, fea)
        return self.final_decode(fea)
class GleanGenerator(nn.Module):
    """Full GLEAN generator: encoder -> StyleGAN2 latent bank -> decoder.

    Args:
        nf: Base filter count shared by the encoder, latent bank and decoder.
        latent_bank_pretrained_weights: Passed through to `Stylegan2LatentBank`.
        latent_bank_max_dim: Passed to `Stylegan2LatentBank` as `max_dim`.
        gen_output_dim: Target output resolution; drives the number of latent blocks
            and decoder levels.
        encoder_rrdb_nb: Number of RRDB blocks in the encoder.
        latent_bank_latent_dim: Latent vector size shared by encoder and latent bank.
        input_dim: Required height and width of the input (checked in `forward`).
        initial_stride: Stride of the encoder's initial convolution.
    """

    def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
                 encoder_rrdb_nb=6, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
        super().__init__()
        self.input_dim = input_dim
        after_stride_dim = input_dim // initial_stride
        latent_blocks, encoder_reductions, decoder_blocks = self._level_counts(after_stride_dim, gen_output_dim)

        self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions,
                                    latent_bank_blocks=latent_blocks,
                                    latent_bank_latent_dim=latent_bank_latent_dim, input_dim=after_stride_dim,
                                    initial_stride=initial_stride)

        # NOTE(review): this filter table is hard-coded; confirm it matches the latent bank's
        # channel widths for gen_output_dim values other than the default.
        # NOTE(review): when decoder_blocks == 0 the slice below is `[-0:]`, which keeps the whole
        # list rather than an empty one; behaviour is left as in the original.
        latent_bank_filters_out = [512, 512, 512, 256, 128]
        latent_bank_filters_out = latent_bank_filters_out[-decoder_blocks:]

        self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf,
                                               max_dim=latent_bank_max_dim,
                                               latent_dim=latent_bank_latent_dim,
                                               encoder_levels=encoder_reductions,
                                               decoder_levels=decoder_blocks)
        self.decoder = GleanDecoder(nf, latent_bank_filters_out)

    @staticmethod
    def _level_counts(after_stride_dim, gen_output_dim):
        """Return `(latent_blocks, encoder_reductions, decoder_blocks)` for the given sizes.

        Uses `math.log2` rather than `math.log(x, 2)`: the two-argument form divides two
        rounded logarithms, and truncating that quotient with `int()` risks an off-by-one.
        """
        # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
        latent_blocks = int(math.log2(gen_output_dim))
        encoder_reductions = int(math.log2(after_stride_dim / 4)) + 1
        decoder_blocks = int(math.log2(gen_output_dim / after_stride_dim))
        return latent_blocks, encoder_reductions, decoder_blocks

    def forward(self, x):
        """Run the generator on `x`, whose last two dims must both equal `input_dim`."""
        assert self.input_dim == x.shape[-1] and self.input_dim == x.shape[-2]
        rrdb_fea, conv_fea, latents = self.encoder(x)
        latent_bank_fea = self.latent_bank(conv_fea, latents)
        return self.decoder(rrdb_fea, latent_bank_fea)
@register_model
def register_glean(opt_net, opt):
    """Build a `GleanGenerator` from the allow-listed entries of `opt_net`.

    Keys of `opt_net` outside the allow-list are ignored; `opt` is not used.
    """
    allowlist = ['nf', 'latent_bank_pretrained_weights', 'latent_bank_max_dim', 'gen_output_dim', 'encoder_rrdb_nb',
                 'latent_bank_latent_dim', 'input_dim', 'initial_stride']
    generator_kwargs = {key: value for key, value in opt_net.items() if key in allowlist}
    return GleanGenerator(**generator_kwargs)