import torch
import torch.nn as nn

from models.arch_util import ConvGnLelu
from models.stylegan.stylegan2_rosinality import Generator


class Stylegan2LatentBank(nn.Module):
    """Frozen, pretrained StyleGAN2 generator used as a GLEAN-style latent bank.

    The wrapped generator's parameters are loaded from a checkpoint and have
    gradients disabled. The only parameters created here that remain trainable
    are the fusion blocks, which merge externally supplied encoder feature maps
    into the generator's activations.
    """

    def __init__(self, pretrained_model_file, encoder_nf=64, max_dim=1024, latent_dim=512,
                 encoder_levels=4, decoder_levels=3):
        """Build the bank, load its weights, freeze it and create the fusion blocks.

        Args:
            pretrained_model_file: Path (or file-like object) handed to
                ``torch.load``; must contain a state dict that matches the
                generator exactly (``strict=True``).
            encoder_nf: Channel count of the shallowest encoder level. Level
                ``i`` is taken to have ``encoder_nf * 2 ** i`` channels.
            max_dim: ``size`` argument passed to the StyleGAN2 ``Generator``.
            latent_dim: ``style_dim`` argument passed to the ``Generator``.
            encoder_levels: Number of encoder feature maps to fuse.
            decoder_levels: Number of levels past the last fused one; used to
                derive ``total_levels``.
        """
        super().__init__()

        # Initialize the bank. Assumed using 'f' generators with mult=2.
        self.bank = Generator(size=max_dim, style_dim=latent_dim, n_mlp=8, channel_multiplier=2)
        # Load onto the CPU so checkpoints saved from a GPU also load on hosts
        # without one; load_state_dict copies the values into the bank's own
        # parameters. NOTE: torch.load unpickles the file, so only point this
        # at checkpoints from a trusted source.
        state_dict = torch.load(pretrained_model_file, map_location="cpu")
        self.bank.load_state_dict(state_dict, strict=True)

        # Shut off training of the latent bank.
        for p in self.bank.parameters():
            p.requires_grad = False
            p.DO_NOT_TRAIN = True

        # TODO: Compute these based on the underlying stylegans channels member variable.
        stylegan_encoder_dims = [512, 512, 512, 512]

        # Initialize the fusion blocks. TODO: Try using the StyledConvs instead of regular ones.
        # Encoder widths are listed deepest-first because forward() consumes the
        # encoder features from the end of the list. The base width now honours
        # `encoder_nf` instead of a hard-coded 64 (identical for the default).
        encoder_output_dims = reversed([encoder_nf * 2 ** i for i in range(encoder_levels)])
        # zip() stops at the shorter input, so at most len(stylegan_encoder_dims)
        # fusion blocks are created regardless of encoder_levels.
        input_dims_by_layer = [eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)]
        self.fusion_blocks = nn.ModuleList([
            ConvGnLelu(in_filters, out_filters, kernel_size=3, activation=True, norm=False, bias=True)
            for in_filters, out_filters in zip(input_dims_by_layer, stylegan_encoder_dims)
        ])

        self.decoder_levels = decoder_levels
        self.decoder_start = encoder_levels - 1
        self.total_levels = encoder_levels + decoder_levels - 1

    # This forward mirrors the forward() pass from the rosinality stylegan2 implementation, with the additions called
    # for from the GLEAN paper. GLEAN mods are annotated with comments.
    # Removed stuff:
    # - Support for split latents (we're spoonfeeding them)
    # - Support for fixed noise inputs
    # - RGB computations -> we only care about the latents
    # - Style MLP -> GLEAN computes the Style inputs directly.
    # - Later layers -> GLEAN terminates at 256 resolution.
    def forward(self, convolutional_features, latent_vectors):
        """Run the frozen bank, fusing encoder features, and collect late activations.

        Args:
            convolutional_features: Indexable collection of encoder feature
                tensors. It is read from the end: ``[-1]`` is fused at level 0,
                ``[-2]`` at level 1, and so on. Each is concatenated with the
                bank activation along dim 1, so all other dims must match.
            latent_vectors: Tensor indexed as ``latent_vectors[:, k]`` to get
                the style input for level ``k``.

        Returns:
            A list with the bank activation of every level ``k`` satisfying
            ``k >= self.decoder_start``, in increasing ``k`` order.
        """
        out = self.bank.input(latent_vectors[:, 0])  # The input here is only used to fetch the batch size.
        out = self.bank.conv1(out, latent_vectors[:, 0], noise=None)

        decoder_outputs = []
        # The bank's convs are consumed in (even, odd) pairs, one pair per level.
        conv_pairs = zip(self.bank.convs[::2], self.bank.convs[1::2])
        for k, (conv1, conv2) in enumerate(conv_pairs):
            # GLEAN mod: fuse the matching encoder feature map into the bank activation.
            if k < len(self.fusion_blocks):
                out = torch.cat([convolutional_features[-k - 1], out], dim=1)
                out = self.fusion_blocks[k](out)

            style = latent_vectors[:, k]
            out = conv1(out, style, noise=None)
            out = conv2(out, style, noise=None)

            # GLEAN mod: keep the activations that feed the decoder.
            if k >= self.decoder_start:
                decoder_outputs.append(out)
            # NOTE(review): this check runs after level k == total_levels has
            # already been computed and appended, so when the bank has enough
            # conv pairs the result holds decoder_levels + 1 tensors. That looks
            # like one level more than intended; left as-is because callers may
            # depend on the current list length - confirm before tightening.
            if k >= self.total_levels:
                break

        return decoder_outputs