glean mods

James Betker 2020-12-19 08:26:07 -07:00
parent f35c034fa5
commit 9377d34ac3
3 changed files with 23 additions and 30 deletions

View File

@@ -32,7 +32,7 @@ class GleanEncoderBlock(nn.Module):
 # and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
 # Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
 class GleanEncoder(nn.Module):
-    def __init__(self, nf, nb, reductions=4, latent_bank_blocks=13, latent_bank_latent_dim=512, input_dim=32):
+    def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32):
         super().__init__()
         self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True)
         self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)])
@@ -41,11 +41,10 @@ class GleanEncoder(nn.Module):
         reducer_output_dim = (input_dim // (2 ** reductions)) ** 2
         reducer_output_nf = nf * 2 ** reductions
         self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, kernel_size=1, activation=True, norm=False, bias=True)
-        # This is a questionable part of this architecture. Apply multiple Denses to separate outputs (as I've done here)?
-        # Apply a single dense, then split the outputs? Who knows..
-        self.latent_linears = nn.ModuleList([EqualLinear(reducer_output_dim * reducer_output_nf, latent_bank_latent_dim,
-                                                         activation="fused_lrelu")
-                                             for _ in range(latent_bank_blocks)])
+        self.latent_linear = EqualLinear(reducer_output_dim * reducer_output_nf,
+                                         latent_bank_latent_dim * latent_bank_blocks,
+                                         activation="fused_lrelu")
+        self.latent_bank_blocks = latent_bank_blocks

     def forward(self, x):
         fea = self.initial_conv(x)
@@ -57,8 +56,7 @@ class GleanEncoder(nn.Module):
             convolutional_features.append(f)

         latents = self.latent_conv(fea)
-        latents = [dense(latents.flatten(1, -1)) for dense in self.latent_linears]
-        latents = torch.stack(latents, dim=1)
+        latents = self.latent_linear(latents.flatten(1, -1)).view(fea.shape[0], self.latent_bank_blocks, -1)

         return rrdb_fea, convolutional_features, latents
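
Note: the single-dense path above regresses every latent vector in one matrix multiply and reshapes to the same `[b, l, f]` layout the old stacked denses produced. A minimal shape sketch, with `nn.Linear` standing in for `EqualLinear` and illustrative sizes (nf=64, reductions=4, input_dim=32 give reducer_output_dim=4, reducer_output_nf=1024):

    import torch
    import torch.nn as nn

    b, reducer_output_dim, reducer_output_nf = 4, 4, 1024
    latent_bank_blocks, latent_bank_latent_dim = 7, 512
    # One dense producing all latents at once, in place of latent_bank_blocks separate denses.
    dense = nn.Linear(reducer_output_dim * reducer_output_nf, latent_bank_latent_dim * latent_bank_blocks)

    fea = torch.randn(b, reducer_output_nf, 2, 2)  # encoder output at 2x2 after 4 reductions
    latents = dense(fea.flatten(1, -1)).view(b, latent_bank_blocks, -1)
    assert latents.shape == (b, latent_bank_blocks, latent_bank_latent_dim)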
@@ -68,26 +66,23 @@ class GleanDecoder(nn.Module):
     # To determine latent_bank_filters, use the `self.channels` map for the desired input dimensions from stylegan2_rosinality.py
     def __init__(self, nf, latent_bank_filters=[512, 256, 128]):
         super().__init__()
-        self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=False, norm=False, bias=True)
-        # The paper calls for pixel shuffling each output of the decoder. We need to make sure that is possible. Doing it by using the latent bank filters as the output filters for each decoder stage
-        assert latent_bank_filters[-1] % 4 == 0
-        decoder_block_shuffled_dims = [nf // 4]
-        decoder_block_shuffled_dims.extend([l // 4 for l in latent_bank_filters])
+        self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=False, bias=True, weight_init_factor=.1)
+        decoder_block_shuffled_dims = [nf] + latent_bank_filters
         self.decoder_blocks = nn.ModuleList([ConvGnLelu(decoder_block_shuffled_dims[i] + latent_bank_filters[i],
                                                         latent_bank_filters[i],
-                                                        kernel_size=3, bias=True, norm=False, activation=False)
+                                                        kernel_size=3, bias=True, norm=False, activation=True,
+                                                        weight_init_factor=.1)
                                              for i in range(len(latent_bank_filters))])
-        self.shuffler = nn.PixelShuffle(2)
+        # TODO: I'm a bit skeptical about this. It doesn't align with RRDB or StyleGAN. It also always produces artifacts in my experience. Try using interpolation instead.
         final_dim = latent_bank_filters[-1]
-        self.final_decode = nn.Sequential(ConvGnLelu(final_dim, final_dim, kernel_size=3, activation=True, bias=True, norm=False),
-                                          ConvGnLelu(final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False))
+        self.final_decode = ConvGnLelu(final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False, weight_init_factor=.1)

     def forward(self, rrdb_fea, latent_bank_fea):
         fea = self.initial_conv(rrdb_fea)
         for i, block in enumerate(self.decoder_blocks):
-            fea = self.shuffler(fea)
+            # The paper calls for PixelShuffle here, but I don't have good experience with that. It also doesn't align with the way the underlying StyleGAN works.
+            fea = nn.functional.interpolate(fea, scale_factor=2, mode="nearest")
             fea = torch.cat([fea, latent_bank_fea[i]], dim=1)
             fea = checkpoint(block, fea)
         return self.final_decode(fea)
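
Note: this is why the `// 4` bookkeeping and the divisibility assert could be dropped: `nn.PixelShuffle(2)` trades channels for resolution (C -> C/4), while nearest-neighbor interpolation upsamples with channels intact, so each decoder stage now sees the full `decoder_block_shuffled_dims[i]` channel count. A quick shape check (illustrative sizes):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    fea = torch.randn(1, 512, 16, 16)
    shuffled = nn.PixelShuffle(2)(fea)                           # [1, 128, 32, 32]: channels divided by 4
    interp = F.interpolate(fea, scale_factor=2, mode="nearest")  # [1, 512, 32, 32]: channels preserved
    assert shuffled.shape == (1, 128, 32, 32)
    assert interp.shape == (1, 512, 32, 32)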
@@ -98,9 +93,8 @@ class GleanGenerator(nn.Module):
                  encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32):
         super().__init__()
         self.input_dim = input_dim
-        latent_blocks = int(math.log(gen_output_dim, 2)) - 1  # From 4x4->gen_output_dim x gen_output_dim
-        latent_blocks = latent_blocks * 2 + 1  # Two styled convolutions per block, + an initial styled conv.
-        self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks * 2 + 1,
+        latent_blocks = int(math.log(gen_output_dim, 2))  # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
+        self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
                                     latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim)
         decoder_blocks = int(math.log(gen_output_dim/input_dim, 2))
         latent_bank_filters_out = [512, 256, 128]  # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
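
Note: worked numbers for the defaults above (gen_output_dim=256, input_dim=32):

    import math

    gen_output_dim, input_dim = 256, 32
    latent_blocks = int(math.log(gen_output_dim, 2))               # 8 latent vectors requested from the encoder
    decoder_blocks = int(math.log(gen_output_dim / input_dim, 2))  # 3 upsampling stages
    print(latent_blocks, decoder_blocks)                           # 8 3

Three decoder stages pair with the three fixed `latent_bank_filters_out` entries the TODO mentions.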

View File

@@ -45,22 +45,21 @@ class Stylegan2LatentBank(nn.Module):
         out = self.bank.input(latent_vectors[:, 0])  # The input here is only used to fetch the batch size.
         out = self.bank.conv1(out, latent_vectors[:, 0], noise=None)

-        i, k = 1, 0
+        k = 0
         decoder_outputs = []
         for conv1, conv2 in zip(self.bank.convs[::2], self.bank.convs[1::2]):
             if k < len(self.fusion_blocks):
                 out = torch.cat([convolutional_features[-k-1], out], dim=1)
                 out = self.fusion_blocks[k](out)
-            out = conv1(out, latent_vectors[:, i], noise=None)
-            out = conv2(out, latent_vectors[:, i + 1], noise=None)
+            out = conv1(out, latent_vectors[:, k], noise=None)
+            out = conv2(out, latent_vectors[:, k], noise=None)
             if k >= self.decoder_start:
                 decoder_outputs.append(out)
             if k >= self.total_levels:
                 break
-            i += 2
             k += 1

         return decoder_outputs
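
Note: with the separate `i` counter removed, both styled convolutions at level `k` share `latent_vectors[:, k]`, so the bank consumes indices 0..total_levels instead of two fresh indices per level. A toy count, assuming total_levels=6 (the loop body runs once more at k == total_levels before breaking):

    total_levels = 6                 # assumed value; illustrative only
    used = {0}                       # bank.input and bank.conv1 both read index 0
    for k in range(total_levels + 1):
        used.add(k)                  # conv1 and conv2 now share latent_vectors[:, k]
    print(len(used))                 # 7 -> consistent with the new latent_bank_blocks=7 encoder default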

View File

@@ -90,23 +90,23 @@ steps:
     losses:
       pix:
         type: pix
-        weight: .05
-        criterion: l1
+        weight: 1
+        criterion: l2
         real: hq
         fake: gen
       feature:
         type: feature
         after: 5000
         which_model_F: vgg
-        criterion: l1
-        weight: 1
+        criterion: l2
+        weight: .01
         real: hq
         fake: gen
       gan_gen_img:
         after: 10000
         type: generator_gan
         gan_type: gan
-        weight: .02
+        weight: .01
         noise: .004
         discriminator: feature_discriminator
         fake: gen
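
Note: written out, the reweighted generator objective above amounts to the sketch below, assuming the usual reading of these fields (`l2` taken as MSE, a vanilla GAN generator loss; helper names are hypothetical):

    import torch
    import torch.nn.functional as F

    def generator_loss(gen, hq, vgg_features, disc, step):
        loss = 1.0 * F.mse_loss(gen, hq)                        # pix: l2, weight 1
        if step >= 5000:                                        # feature loss enabled after 5000 steps
            loss = loss + .01 * F.mse_loss(vgg_features(gen), vgg_features(hq))
        if step >= 10000:                                       # GAN loss enabled after 10000 steps
            pred = disc(gen + torch.randn_like(gen) * .004)     # noise: .004 on the discriminator input
            loss = loss + .01 * F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))
        return loss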