GLEAN mod to support custom initial strides

This commit is contained in:
James Betker 2020-12-26 13:51:14 -07:00
parent 3fd627fc62
commit f9be049adb

View File

@ -33,9 +33,9 @@ class GleanEncoderBlock(nn.Module):
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank. # and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper. # Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
class GleanEncoder(nn.Module): class GleanEncoder(nn.Module):
def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32): def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
super().__init__() super().__init__()
self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True) self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride)
self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)]) self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)])
self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)]) self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)])
@ -91,12 +91,12 @@ class GleanDecoder(nn.Module):
class GleanGenerator(nn.Module): class GleanGenerator(nn.Module):
def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256, def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32): encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
super().__init__() super().__init__()
self.input_dim = input_dim self.input_dim = input_dim // initial_stride
latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks, self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim) latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim, initial_stride=initial_stride)
decoder_blocks = int(math.log(gen_output_dim/input_dim, 2)) decoder_blocks = int(math.log(gen_output_dim/input_dim, 2))
latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though. latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim, self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
@ -113,4 +113,9 @@ class GleanGenerator(nn.Module):
@register_model
def register_glean(opt_net, opt):
    """Build a GleanGenerator from the network section of the options.

    Every entry of `opt_net` except the model-selection keys listed below is
    forwarded to the GleanGenerator constructor as a keyword argument, so
    constructor options such as `initial_stride` can be set from the config
    without editing this function. The remaining keys must therefore match
    GleanGenerator's parameter names.

    Args:
        opt_net: Mapping of options for this network.
        opt: Second argument of the registration signature; not used here.
            NOTE(review): presumably the full options tree - confirm.

    Returns:
        The constructed GleanGenerator.
    """
    # These keys pick which model to build; they are not constructor arguments.
    exclusions = ('which_model_G', 'type')
    # Read from opt_net, not opt: the per-network settings (nf, etc.) live in
    # opt_net, which is also where the excluded keys are expected.
    kwargs = {k: v for k, v in opt_net.items() if k not in exclusions}
    return GleanGenerator(**kwargs)