DL-Art-School/codes/models/glean/glean.py

122 lines
6.3 KiB
Python
Raw Normal View History

2020-12-18 23:04:19 +00:00
import math
import torch.nn as nn
import torch
from models.RRDBNet_arch import RRDB
from models.arch_util import ConvGnLelu
# GleanEncoderBlock (defined below) produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.stylegan.stylegan2_rosinality import EqualLinear
from trainer.networks import register_model
2020-12-18 23:04:19 +00:00
from utils.util import checkpoint, sequential_checkpoint
class GleanEncoderBlock(nn.Module):
    """Single reduction stage of the GLEAN encoder.

    For an input feature map the block yields two tensors: the output of a
    two-conv stack whose filter count is double the input's, and a 1x1-conv
    projection of the unmodified input (the convolutional feature `f`).

    Args:
        nf: Number of filters in the incoming feature map.
    """

    def __init__(self, nf):
        super().__init__()
        doubled_nf = nf * 2
        # 1x1 projection applied directly to the block input.
        self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
        # The first conv is strided (stride=2) and doubles the filters; the second keeps the width.
        strided_conv = ConvGnLelu(nf, doubled_nf, kernel_size=3, stride=2, activation=True, norm=False, bias=False)
        refine_conv = ConvGnLelu(doubled_nf, doubled_nf, kernel_size=3, activation=True, norm=False, bias=False)
        self.process = nn.Sequential(strided_conv, refine_conv)

    def forward(self, x):
        """Return `(reduced_features, structural_latent)` computed from `x`."""
        projected_input = self.structural_latent_conv(x)
        reduced_features = self.process(x)
        return reduced_features, projected_input
# Produces RRDB features, a list of convolutional features (`f` shape=[l][b,c,h,w] l=levels aka f_sub)
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
class GleanEncoder(nn.Module):
    """Encoder half of GLEAN.

    Runs an initial 7x7 conv and a stack of RRDB blocks, then a chain of
    `GleanEncoderBlock` reducers. The final reduced feature map is passed
    through a 1x1 conv, flattened and projected to the latent vectors.

    Args:
        nf: Filter count produced by the initial conv and used by the RRDB blocks.
        nb: Number of RRDB blocks.
        reductions: Number of `GleanEncoderBlock` stages; stage `i` receives `nf * 2**i` filters.
        latent_bank_blocks: Number of latent vectors emitted per sample.
        latent_bank_latent_dim: Width of each latent vector.
        input_dim: Spatial size used to size the latent linear layer, via
            `(input_dim // 2**reductions) ** 2`.
        initial_stride: Stride of the initial conv.

    NOTE(review): the latent linear sizing does not take `initial_stride` into
    account, so `input_dim` presumably has to be the feature size *after* the
    initial conv whenever `initial_stride != 1` -- confirm against callers.
    """

    def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
        super().__init__()
        self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride)
        self.rrdb_blocks = nn.Sequential(*(RRDB(nf) for _ in range(nb)))
        self.reducers = nn.ModuleList(GleanEncoderBlock(nf * 2 ** level) for level in range(reductions))

        # Every reducer doubles the filter count; the linear layer is sized for a
        # spatial side of input_dim // 2**reductions.
        total_reduction = 2 ** reductions
        reducer_output_nf = nf * total_reduction
        reducer_output_dim = (input_dim // total_reduction) ** 2
        self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, kernel_size=1, activation=True, norm=False, bias=True)
        self.latent_linear = EqualLinear(
            reducer_output_dim * reducer_output_nf,
            latent_bank_latent_dim * latent_bank_blocks,
            activation="fused_lrelu",
        )
        self.latent_bank_blocks = latent_bank_blocks

    def forward(self, x):
        """Encode `x`.

        Returns:
            A tuple `(rrdb_fea, convolutional_features, latents)`: the RRDB stack
            output, a list with one structural feature per reducer, and the latents
            viewed as `[batch, latent_bank_blocks, -1]`.
        """
        rrdb_fea = sequential_checkpoint(self.rrdb_blocks, len(self.rrdb_blocks), self.initial_conv(x))

        convolutional_features = []
        reduced = rrdb_fea
        for reducer in self.reducers:
            reduced, structural = checkpoint(reducer, reduced)
            convolutional_features.append(structural)

        batch_size = reduced.shape[0]
        flat_latents = self.latent_conv(reduced).flatten(1, -1)
        latents = self.latent_linear(flat_latents).view(batch_size, self.latent_bank_blocks, -1)
        return rrdb_fea, convolutional_features, latents
# Produces an image by fusing the output features from the latent bank.
class GleanDecoder(nn.Module):
    """Produces an image by fusing the output features from the latent bank.

    Args:
        nf: Filter count of the incoming RRDB features.
        latent_bank_filters: Sequence with the filter count of each latent bank
            feature map that will be fused, in fusion order. Any sequence (list or
            tuple) is accepted and it is never mutated.
    """

    # To determine latent_bank_filters, use the `self.channels` map for the desired input dimensions from stylegan2_rosinality.py
    def __init__(self, nf, latent_bank_filters=(512, 256, 128)):
        super().__init__()
        # Copy into a list so the (immutable) default is never shared or mutated and
        # so tuples work with the list concatenation below.
        filters = list(latent_bank_filters)
        self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=False, bias=True, weight_init_factor=.1)

        # Block i receives the previous stage's output (nf for the first block)
        # concatenated with the i-th latent bank feature map. zip() stops at the
        # shorter sequence, so the trailing entry of `[nf] + filters` is unused.
        self.decoder_blocks = nn.ModuleList([
            ConvGnLelu(prev_nf + bank_nf, bank_nf,
                       kernel_size=3, bias=True, norm=False, activation=True,
                       weight_init_factor=.1)
            for prev_nf, bank_nf in zip([nf] + filters, filters)
        ])
        final_dim = filters[-1]
        self.final_decode = ConvGnLelu(final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False, weight_init_factor=.1)

    def forward(self, rrdb_fea, latent_bank_fea):
        """Decode an image from the RRDB features and the latent bank features.

        Args:
            rrdb_fea: Feature map fed to the initial conv.
            latent_bank_fea: Indexable collection of feature maps; entry `i` is
                concatenated along dim 1 with the 2x-upsampled features before
                decoder block `i`.

        Returns:
            The output of the final 3-filter conv.
        """
        fea = self.initial_conv(rrdb_fea)
        for i, block in enumerate(self.decoder_blocks):
            # The paper calls for PixelShuffle here, but I don't have good experience with that. It also doesn't align with the way the underlying StyleGAN works.
            fea = nn.functional.interpolate(fea, scale_factor=2, mode="nearest")
            fea = torch.cat([fea, latent_bank_fea[i]], dim=1)
            fea = checkpoint(block, fea)
        return self.final_decode(fea)
class GleanGenerator(nn.Module):
    """Full GLEAN generator: encoder -> StyleGAN2 latent bank -> decoder.

    Args:
        nf: Base filter count shared by the encoder, latent bank adapter and decoder.
        latent_bank_pretrained_weights: Forwarded to `Stylegan2LatentBank`.
        latent_bank_max_dim: Forwarded to `Stylegan2LatentBank` as `max_dim`.
        gen_output_dim: Output image size; used to derive the number of latent
            blocks and decoder levels.
        encoder_rrdb_nb: Number of RRDB blocks in the encoder.
        encoder_reductions: Number of reduction stages in the encoder.
        latent_bank_latent_dim: Width of each latent vector.
        input_dim: Height/width of the (square) input image expected by `forward`.
        initial_stride: Stride of the encoder's initial conv.
    """

    def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
                 encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
        super().__init__()
        # forward() validates against the raw image size. The encoder's first conv
        # applies `initial_stride`, so everything downstream of it must be sized
        # from the post-stride dimension; otherwise any initial_stride > 1 sizes
        # the encoder's latent linear layer for a larger feature map than it gets.
        self.input_dim = input_dim
        after_stride_dim = input_dim // initial_stride

        # math.log2 is exact for powers of two, so int() cannot truncate a value
        # that floating point error left just below the intended integer.
        latent_blocks = int(math.log2(gen_output_dim))  # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
        self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
                                    latent_bank_latent_dim=latent_bank_latent_dim, input_dim=after_stride_dim,
                                    initial_stride=initial_stride)

        decoder_blocks = int(math.log2(gen_output_dim / after_stride_dim))
        latent_bank_filters_out = [512, 256, 128]  # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
        self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
                                               latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions,
                                               decoder_levels=decoder_blocks)
        self.decoder = GleanDecoder(nf, latent_bank_filters_out)

    def forward(self, x):
        """Generate an image from `x`, whose last two dims must both equal `input_dim`."""
        assert self.input_dim == x.shape[-1] and self.input_dim == x.shape[-2], \
            f"expected a {self.input_dim}x{self.input_dim} input, got {x.shape[-2]}x{x.shape[-1]}"
        rrdb_fea, conv_fea, latents = self.encoder(x)
        latent_bank_fea = self.latent_bank(conv_fea, latents)
        return self.decoder(rrdb_fea, latent_bank_fea)
@register_model
def register_glean(opt_net, opt):
    """Build a `GleanGenerator` from a network configuration mapping.

    Every entry of `opt_net` except the model-selection keys is forwarded to
    `GleanGenerator` as a keyword argument.

    Args:
        opt_net: Mapping of settings for this network (presumably the per-network
            section of the training options -- confirm against the registry caller).
        opt: Accepted to match the registration signature; not used here.

    Returns:
        The constructed `GleanGenerator`.
    """
    # The excluded keys select/describe the model rather than configure it. They
    # are network-level keys, so the kwargs must come from `opt_net`, not `opt`.
    exclusions = ('which_model_G', 'type')
    kwargs = {k: v for k, v in opt_net.items() if k not in exclusions}
    return GleanGenerator(**kwargs)