ba543d1152
- Fixes fixed upscale factor issues - Refines a few ops to decrease computation & parameterization
127 lines
6.6 KiB
Python
127 lines
6.6 KiB
Python
import math
|
|
|
|
import torch.nn as nn
|
|
import torch
|
|
|
|
from models.RRDBNet_arch import RRDB
|
|
from models.arch_util import ConvGnLelu
|
|
|
|
|
|
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters (capped at `max_nf`).
|
|
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
|
|
from models.stylegan.stylegan2_rosinality import EqualLinear
|
|
from trainer.networks import register_model
|
|
from utils.util import checkpoint, sequential_checkpoint
|
|
|
|
|
|
class GleanEncoderBlock(nn.Module):
    """One encoder stage producing a structural feature and a reduced feature map.

    Two branches read the same input:
      * ``structural_latent_conv``: a kernel-1 ConvGnLelu that keeps ``nf`` channels.
      * ``process``: a stride-2, kernel-3 ConvGnLelu followed by a kernel-3 ConvGnLelu,
        widening to ``min(2 * nf, max_nf)`` channels (the stride presumably halves the
        spatial size).

    Args:
        nf: number of input channels.
        max_nf: upper bound on the channel count of the reduced output.
    """

    def __init__(self, nf, max_nf):
        super().__init__()
        out_nf = min(nf * 2, max_nf)
        # Registered first, ahead of `process`, to keep the same parameter/init order.
        self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
        downsample = ConvGnLelu(nf, out_nf, kernel_size=3, stride=2, activation=True, norm=False, bias=False)
        refine = ConvGnLelu(out_nf, out_nf, kernel_size=3, activation=True, norm=False, bias=False)
        self.process = nn.Sequential(downsample, refine)

    def forward(self, x):
        """Return ``(reduced_features, structural_latent)`` computed from ``x``."""
        latent = self.structural_latent_conv(x)
        return self.process(x), latent
|
|
|
|
|
|
# Produces RRDB features, a list of convolutional features (`f` shape=[l][b,c,h,w] l=levels aka f_sub)
|
|
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
|
|
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
|
|
class GleanEncoder(nn.Module):
    """Encode an RGB image into RRDB features, per-level structural features and latent codes.

    Pipeline: a kernel-7 ConvGnLelu (stride ``initial_stride``) -> ``nb`` RRDB blocks ->
    ``reductions`` GleanEncoderBlocks (each contributing one structural feature) -> a
    stride-2 ConvGnLelu -> an EqualLinear that emits ``latent_bank_blocks`` latent vectors.

    Args:
        nf: channels after the initial conv and through the RRDB trunk.
        nb: number of RRDB blocks.
        max_nf: channel cap for the reduction stages.
        reductions: number of GleanEncoderBlock stages.
        latent_bank_blocks: number of latent vectors produced per sample.
        latent_bank_latent_dim: size of each latent vector.
        input_dim: spatial side length seen by the RRDB trunk, i.e. after the initial stride.
            It sizes the linear layer and assumes each stride-2 stage halves the side
            (TODO confirm for non-power-of-two sizes).
        initial_stride: stride of the initial conv.

    Returns (from ``forward``):
        ``(rrdb_fea, convolutional_features, latents)`` where ``convolutional_features`` has
        one entry per reduction stage and ``latents`` is viewed as
        ``[batch, latent_bank_blocks, latent_bank_latent_dim]``.
    """

    def __init__(self, nf, nb, max_nf=512, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
        super().__init__()
        self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride)
        self.rrdb_blocks = nn.Sequential(*(RRDB(nf) for _ in range(nb)))
        self.reducers = nn.ModuleList(
            GleanEncoderBlock(min(nf * 2 ** level, max_nf), max_nf) for level in range(reductions)
        )

        # Channel count and side length after all reducers plus the stride-2 latent conv.
        final_nf = min(nf * 2 ** reductions, max_nf)
        final_side = input_dim // 2 ** (reductions + 1)
        self.latent_conv = ConvGnLelu(final_nf, final_nf, stride=2, kernel_size=3, activation=True, norm=False, bias=True)
        self.latent_linear = EqualLinear(final_side * final_side * final_nf,
                                         latent_bank_latent_dim * latent_bank_blocks,
                                         activation="fused_lrelu")
        self.latent_bank_blocks = latent_bank_blocks

    def forward(self, x):
        # `sequential_checkpoint` / `checkpoint` presumably apply gradient checkpointing (utils.util).
        trunk = sequential_checkpoint(self.rrdb_blocks, len(self.rrdb_blocks), self.initial_conv(x))

        fea = trunk
        structural = []
        for reducer in self.reducers:
            fea, skip = checkpoint(reducer, fea)
            structural.append(skip)

        flat = self.latent_conv(fea).flatten(1, -1)
        latents = self.latent_linear(flat).view(fea.shape[0], self.latent_bank_blocks, -1)
        return trunk, structural, latents
|
|
|
|
|
|
# Produces an image by fusing the output features from the latent bank.
|
|
class GleanDecoder(nn.Module):
    """Fuse RRDB trunk features with latent-bank features into a 3-channel image.

    Each stage i:
      1. upsamples the running features 2x with nearest-neighbour interpolation,
      2. concatenates ``latent_bank_fea[i]`` along dim 1,
      3. applies a kernel-3 ConvGnLelu that outputs ``latent_bank_filters[i]`` channels.
    A final kernel-3 ConvGnLelu maps to 3 channels.

    Args:
        nf: channel count of ``rrdb_fea``.
        latent_bank_filters: channel count of each latent-bank feature map, in consumption
            order. To determine these, use the ``self.channels`` map for the desired input
            dimensions from stylegan2_rosinality.py. Any sequence of ints is accepted.

    Raises:
        ValueError: if ``latent_bank_filters`` is empty.
    """

    def __init__(self, nf, latent_bank_filters=(512, 256, 128)):
        super().__init__()
        # A tuple default avoids the shared-mutable-default pitfall. Copy into a fresh list so
        # any sequence type works with the list concatenation below.
        latent_bank_filters = list(latent_bank_filters)
        if not latent_bank_filters:
            raise ValueError("latent_bank_filters must contain at least one entry")

        self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=False, bias=True, weight_init_factor=.1)

        # Stage i receives the previous stage's output channels (nf for the first stage) plus
        # the latent-bank channels concatenated at that stage.
        incoming_nf = [nf] + latent_bank_filters[:-1]
        self.decoder_blocks = nn.ModuleList([
            ConvGnLelu(in_nf + bank_nf, bank_nf, kernel_size=3, bias=True, norm=False, activation=True,
                       weight_init_factor=.1)
            for in_nf, bank_nf in zip(incoming_nf, latent_bank_filters)
        ])

        self.final_decode = ConvGnLelu(latent_bank_filters[-1], 3, kernel_size=3, activation=False, bias=True, norm=False, weight_init_factor=.1)

    def forward(self, rrdb_fea, latent_bank_fea):
        fea = self.initial_conv(rrdb_fea)
        for i, block in enumerate(self.decoder_blocks):
            # The paper calls for PixelShuffle here, but I don't have good experience with that. It also doesn't align with the way the underlying StyleGAN works.
            fea = nn.functional.interpolate(fea, scale_factor=2, mode="nearest")
            fea = torch.cat([fea, latent_bank_fea[i]], dim=1)
            fea = checkpoint(block, fea)
        return self.final_decode(fea)
|
|
|
|
|
|
class GleanGenerator(nn.Module):
|
|
def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
|
|
encoder_rrdb_nb=6, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
|
|
super().__init__()
|
|
self.input_dim = input_dim
|
|
after_stride_dim = input_dim // initial_stride
|
|
latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
|
|
encoder_reductions = int(math.log(after_stride_dim / 4, 2)) + 1
|
|
self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
|
|
latent_bank_latent_dim=latent_bank_latent_dim, input_dim=after_stride_dim, initial_stride=initial_stride)
|
|
decoder_blocks = int(math.log(gen_output_dim/after_stride_dim, 2))
|
|
latent_bank_filters_out = [512, 512, 512, 256, 128]
|
|
latent_bank_filters_out = latent_bank_filters_out[-decoder_blocks:]
|
|
self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
|
|
latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions,
|
|
decoder_levels=decoder_blocks)
|
|
self.decoder = GleanDecoder(nf, latent_bank_filters_out)
|
|
|
|
def forward(self, x):
|
|
assert self.input_dim == x.shape[-1] and self.input_dim == x.shape[-2]
|
|
rrdb_fea, conv_fea, latents = self.encoder(x)
|
|
latent_bank_fea = self.latent_bank(conv_fea, latents)
|
|
return self.decoder(rrdb_fea, latent_bank_fea)
|
|
|
|
|
|
@register_model
def register_glean(opt_net, opt):
    """Construct a GleanGenerator from the network options mapping.

    Only the allowlisted keys of ``opt_net`` are forwarded as keyword arguments, in
    ``opt_net``'s iteration order. All other keys, and ``opt`` itself, are ignored.
    """
    accepted_keys = ('nf', 'latent_bank_pretrained_weights', 'latent_bank_max_dim', 'gen_output_dim',
                     'encoder_rrdb_nb', 'latent_bank_latent_dim', 'input_dim', 'initial_stride')
    generator_kwargs = {key: value for key, value in opt_net.items() if key in accepted_keys}
    return GleanGenerator(**generator_kwargs)
|