Glean mods

- Fixes fixed upscale factor issues
- Refines a few ops to decrease computation & parameterization
This commit is contained in:
James Betker 2020-12-27 12:25:06 -07:00
parent 5e2e605a50
commit ba543d1152
3 changed files with 30 additions and 20 deletions

View File

@ -3,6 +3,7 @@ import itertools
import random import random
import cv2 import cv2
import kornia
import numpy as np import numpy as np
import torch import torch
import os import os
@ -25,6 +26,7 @@ class ImageFolderDataset:
self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image
# from the same video source. Search for 'fetch_alt_image' for more info. # from the same video source. Search for 'fetch_alt_image' for more info.
self.skip_lq = opt['skip_lq'] self.skip_lq = opt['skip_lq']
self.disable_flip = opt['disable_flip']
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors. assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
if not isinstance(self.paths, list): if not isinstance(self.paths, list):
self.paths = [self.paths] self.paths = [self.paths]
@ -108,6 +110,9 @@ class ImageFolderDataset:
def __getitem__(self, item): def __getitem__(self, item):
hq = util.read_img(None, self.image_paths[item], rgb=True) hq = util.read_img(None, self.image_paths[item], rgb=True)
if not self.disable_flip and random.random() < .5:
hq = hq[:, ::-1, :]
if self.labeler: if self.labeler:
assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet. assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet.
dim = hq.shape[0] dim = hq.shape[0]

View File

@ -15,12 +15,13 @@ from utils.util import checkpoint, sequential_checkpoint
class GleanEncoderBlock(nn.Module): class GleanEncoderBlock(nn.Module):
def __init__(self, nf): def __init__(self, nf, max_nf):
super().__init__() super().__init__()
self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True) self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True)
top_nf = min(nf*2, max_nf)
self.process = nn.Sequential( self.process = nn.Sequential(
ConvGnLelu(nf, nf*2, kernel_size=3, stride=2, activation=True, norm=False, bias=False), ConvGnLelu(nf, top_nf, kernel_size=3, stride=2, activation=True, norm=False, bias=False),
ConvGnLelu(nf*2, nf*2, kernel_size=3, activation=True, norm=False, bias=False) ConvGnLelu(top_nf, top_nf, kernel_size=3, activation=True, norm=False, bias=False)
) )
def forward(self, x): def forward(self, x):
@ -33,15 +34,15 @@ class GleanEncoderBlock(nn.Module):
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank. # and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper. # Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
class GleanEncoder(nn.Module): class GleanEncoder(nn.Module):
def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): def __init__(self, nf, nb, max_nf=512, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
super().__init__() super().__init__()
self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride) self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride)
self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)]) self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)])
self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)]) self.reducers = nn.ModuleList([GleanEncoderBlock(min(nf * 2 ** i, max_nf), max_nf) for i in range(reductions)])
reducer_output_dim = (input_dim // (2 ** reductions)) ** 2 reducer_output_dim = (input_dim // (2 ** (reductions + 1))) ** 2
reducer_output_nf = nf * 2 ** reductions reducer_output_nf = min(nf * 2 ** reductions, max_nf)
self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, kernel_size=1, activation=True, norm=False, bias=True) self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, stride=2, kernel_size=3, activation=True, norm=False, bias=True)
self.latent_linear = EqualLinear(reducer_output_dim * reducer_output_nf, self.latent_linear = EqualLinear(reducer_output_dim * reducer_output_nf,
latent_bank_latent_dim * latent_bank_blocks, latent_bank_latent_dim * latent_bank_blocks,
activation="fused_lrelu") activation="fused_lrelu")
@ -91,14 +92,17 @@ class GleanDecoder(nn.Module):
class GleanGenerator(nn.Module): class GleanGenerator(nn.Module):
def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256, def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): encoder_rrdb_nb=6, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
super().__init__() super().__init__()
self.input_dim = input_dim // initial_stride self.input_dim = input_dim
after_stride_dim = input_dim // initial_stride
latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
encoder_reductions = int(math.log(after_stride_dim / 4, 2)) + 1
self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks, self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim, initial_stride=initial_stride) latent_bank_latent_dim=latent_bank_latent_dim, input_dim=after_stride_dim, initial_stride=initial_stride)
decoder_blocks = int(math.log(gen_output_dim/input_dim, 2)) decoder_blocks = int(math.log(gen_output_dim/after_stride_dim, 2))
latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though. latent_bank_filters_out = [512, 512, 512, 256, 128]
latent_bank_filters_out = latent_bank_filters_out[-decoder_blocks:]
self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim, self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions, latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions,
decoder_levels=decoder_blocks) decoder_levels=decoder_blocks)
@ -114,8 +118,9 @@ class GleanGenerator(nn.Module):
@register_model @register_model
def register_glean(opt_net, opt): def register_glean(opt_net, opt):
kwargs = {} kwargs = {}
exclusions = ['which_model_G', 'type'] allowlist = ['nf', 'latent_bank_pretrained_weights', 'latent_bank_max_dim', 'gen_output_dim', 'encoder_rrdb_nb', 'latent_bank_latent_dim',
for k, v in opt.items(): 'input_dim', 'initial_stride']
if k not in exclusions: for k, v in opt_net.items():
if k in allowlist:
kwargs[k] = v kwargs[k] = v
return GleanGenerator(**kwargs) return GleanGenerator(**kwargs)

View File

@ -6,7 +6,7 @@ from models.stylegan.stylegan2_rosinality import Generator
class Stylegan2LatentBank(nn.Module): class Stylegan2LatentBank(nn.Module):
def __init__(self, pretrained_model_file, encoder_nf=64, max_dim=1024, latent_dim=512, encoder_levels=4, decoder_levels=3): def __init__(self, pretrained_model_file, encoder_nf=64, encoder_max_nf=512, max_dim=1024, latent_dim=512, encoder_levels=4, decoder_levels=3):
super().__init__() super().__init__()
# Initialize the bank. # Initialize the bank.
@ -19,11 +19,11 @@ class Stylegan2LatentBank(nn.Module):
p.requires_grad = False p.requires_grad = False
p.DO_NOT_TRAIN = True p.DO_NOT_TRAIN = True
# TODO: Compute these based on the underlying stylegans channels member variable. # These are from `stylegan_rosinality.py`, search for `self.channels = {`.
stylegan_encoder_dims = [512, 512, 512, 512] stylegan_encoder_dims = [512, 512, 512, 512, 512, 256, 128, 64, 32]
# Initialize the fusion blocks. TODO: Try using the StyledConvs instead of regular ones. # Initialize the fusion blocks. TODO: Try using the StyledConvs instead of regular ones.
encoder_output_dims = reversed([64 * 2 ** i for i in range(encoder_levels)]) encoder_output_dims = reversed([min(encoder_nf * 2 ** i, encoder_max_nf) for i in range(encoder_levels)])
input_dims_by_layer = [eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)] input_dims_by_layer = [eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)]
self.fusion_blocks = nn.ModuleList([ConvGnLelu(in_filters, out_filters, kernel_size=3, activation=True, norm=False, bias=True) self.fusion_blocks = nn.ModuleList([ConvGnLelu(in_filters, out_filters, kernel_size=3, activation=True, norm=False, bias=True)
for in_filters, out_filters in zip(input_dims_by_layer, stylegan_encoder_dims)]) for in_filters, out_filters in zip(input_dims_by_layer, stylegan_encoder_dims)])