diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 2bc1a984..fe334eab 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -3,6 +3,7 @@ import itertools import random import cv2 +import kornia import numpy as np import torch import os @@ -25,6 +26,7 @@ class ImageFolderDataset: self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image # from the same video source. Search for 'fetch_alt_image' for more info. self.skip_lq = opt['skip_lq'] + self.disable_flip = opt['disable_flip'] assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors. if not isinstance(self.paths, list): self.paths = [self.paths] @@ -108,6 +110,9 @@ class ImageFolderDataset: def __getitem__(self, item): hq = util.read_img(None, self.image_paths[item], rgb=True) + if not self.disable_flip and random.random() < .5: + hq = hq[:, ::-1, :] + if self.labeler: assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet. dim = hq.shape[0] diff --git a/codes/models/glean/glean.py b/codes/models/glean/glean.py index c8f981ea..497f0a7a 100644 --- a/codes/models/glean/glean.py +++ b/codes/models/glean/glean.py @@ -15,12 +15,13 @@ from utils.util import checkpoint, sequential_checkpoint class GleanEncoderBlock(nn.Module): - def __init__(self, nf): + def __init__(self, nf, max_nf): super().__init__() self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True) + top_nf = min(nf*2, max_nf) self.process = nn.Sequential( - ConvGnLelu(nf, nf*2, kernel_size=3, stride=2, activation=True, norm=False, bias=False), - ConvGnLelu(nf*2, nf*2, kernel_size=3, activation=True, norm=False, bias=False) + ConvGnLelu(nf, top_nf, kernel_size=3, stride=2, activation=True, norm=False, bias=False), + ConvGnLelu(top_nf, top_nf, kernel_size=3, activation=True, norm=False, bias=False) ) def forward(self, x): @@ -33,15 +34,15 @@ class GleanEncoderBlock(nn.Module): # and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank. # Note that latent levels and convolutional feature levels do not necessarily match, per the paper. class GleanEncoder(nn.Module): - def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): + def __init__(self, nf, nb, max_nf=512, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): super().__init__() self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride) self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)]) - self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)]) + self.reducers = nn.ModuleList([GleanEncoderBlock(min(nf * 2 ** i, max_nf), max_nf) for i in range(reductions)]) - reducer_output_dim = (input_dim // (2 ** reductions)) ** 2 - reducer_output_nf = nf * 2 ** reductions - self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, kernel_size=1, activation=True, norm=False, bias=True) + reducer_output_dim = (input_dim // (2 ** (reductions + 1))) ** 2 + reducer_output_nf = min(nf * 2 ** reductions, max_nf) + self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, stride=2, kernel_size=3, activation=True, norm=False, bias=True) self.latent_linear = EqualLinear(reducer_output_dim * reducer_output_nf, latent_bank_latent_dim * latent_bank_blocks, activation="fused_lrelu") @@ -91,14 +92,17 @@ class GleanDecoder(nn.Module): class GleanGenerator(nn.Module): def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256, - encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): + encoder_rrdb_nb=6, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): super().__init__() - self.input_dim = input_dim // initial_stride + self.input_dim = input_dim + after_stride_dim = input_dim // initial_stride latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv + encoder_reductions = int(math.log(after_stride_dim / 4, 2)) + 1 self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks, - latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim, initial_stride=initial_stride) - decoder_blocks = int(math.log(gen_output_dim/input_dim, 2)) - latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though. + latent_bank_latent_dim=latent_bank_latent_dim, input_dim=after_stride_dim, initial_stride=initial_stride) + decoder_blocks = int(math.log(gen_output_dim/after_stride_dim, 2)) + latent_bank_filters_out = [512, 512, 512, 256, 128] + latent_bank_filters_out = latent_bank_filters_out[-decoder_blocks:] self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim, latent_dim=latent_bank_latent_dim, encoder_levels=encoder_reductions, decoder_levels=decoder_blocks) @@ -114,8 +118,9 @@ class GleanGenerator(nn.Module): @register_model def register_glean(opt_net, opt): kwargs = {} - exclusions = ['which_model_G', 'type'] - for k, v in opt.items(): - if k not in exclusions: + allowlist = ['nf', 'latent_bank_pretrained_weights', 'latent_bank_max_dim', 'gen_output_dim', 'encoder_rrdb_nb', 'latent_bank_latent_dim', + 'input_dim', 'initial_stride'] + for k, v in opt_net.items(): + if k in allowlist: kwargs[k] = v return GleanGenerator(**kwargs) diff --git a/codes/models/glean/stylegan2_latent_bank.py b/codes/models/glean/stylegan2_latent_bank.py index 3adc9735..f7a191ae 100644 --- a/codes/models/glean/stylegan2_latent_bank.py +++ b/codes/models/glean/stylegan2_latent_bank.py @@ -6,7 +6,7 @@ from models.stylegan.stylegan2_rosinality import Generator class Stylegan2LatentBank(nn.Module): - def __init__(self, pretrained_model_file, encoder_nf=64, max_dim=1024, latent_dim=512, encoder_levels=4, decoder_levels=3): + def __init__(self, pretrained_model_file, encoder_nf=64, encoder_max_nf=512, max_dim=1024, latent_dim=512, encoder_levels=4, decoder_levels=3): super().__init__() # Initialize the bank. @@ -19,11 +19,11 @@ class Stylegan2LatentBank(nn.Module): p.requires_grad = False p.DO_NOT_TRAIN = True - # TODO: Compute these based on the underlying stylegans channels member variable. - stylegan_encoder_dims = [512, 512, 512, 512] + # These are from `stylegan_rosinality.py`, search for `self.channels = {`. + stylegan_encoder_dims = [512, 512, 512, 512, 512, 256, 128, 64, 32] # Initialize the fusion blocks. TODO: Try using the StyledConvs instead of regular ones. - encoder_output_dims = reversed([64 * 2 ** i for i in range(encoder_levels)]) + encoder_output_dims = reversed([min(encoder_nf * 2 ** i, encoder_max_nf) for i in range(encoder_levels)]) input_dims_by_layer = [eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)] self.fusion_blocks = nn.ModuleList([ConvGnLelu(in_filters, out_filters, kernel_size=3, activation=True, norm=False, bias=True) for in_filters, out_filters in zip(input_dims_by_layer, stylegan_encoder_dims)])