GLEAN mod to support custom initial strides
This commit is contained in:
parent
3fd627fc62
commit
f9be049adb
|
@ -33,9 +33,9 @@ class GleanEncoderBlock(nn.Module):
|
||||||
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
|
# and latent vectors (`C` shape=[b,l,f] l=levels aka C_sub) for use with the latent bank.
|
||||||
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
|
# Note that latent levels and convolutional feature levels do not necessarily match, per the paper.
|
||||||
class GleanEncoder(nn.Module):
|
class GleanEncoder(nn.Module):
|
||||||
def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32):
|
def __init__(self, nf, nb, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True)
|
self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride)
|
||||||
self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)])
|
self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)])
|
||||||
self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)])
|
self.reducers = nn.ModuleList([GleanEncoderBlock(nf * 2 ** i) for i in range(reductions)])
|
||||||
|
|
||||||
|
@ -91,12 +91,12 @@ class GleanDecoder(nn.Module):
|
||||||
|
|
||||||
class GleanGenerator(nn.Module):
|
class GleanGenerator(nn.Module):
|
||||||
def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
|
def __init__(self, nf, latent_bank_pretrained_weights, latent_bank_max_dim=1024, gen_output_dim=256,
|
||||||
encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32):
|
encoder_rrdb_nb=6, encoder_reductions=4, latent_bank_latent_dim=512, input_dim=32, initial_stride=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim // initial_stride
|
||||||
latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
|
latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv
|
||||||
self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
|
self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks,
|
||||||
latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim)
|
latent_bank_latent_dim=latent_bank_latent_dim, input_dim=input_dim, initial_stride=initial_stride)
|
||||||
decoder_blocks = int(math.log(gen_output_dim/input_dim, 2))
|
decoder_blocks = int(math.log(gen_output_dim/input_dim, 2))
|
||||||
latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
|
latent_bank_filters_out = [512, 256, 128] # TODO: Use decoder_blocks to synthesize the correct value for latent_bank_filters here. The fixed defaults will work fine for testing, though.
|
||||||
self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
|
self.latent_bank = Stylegan2LatentBank(latent_bank_pretrained_weights, encoder_nf=nf, max_dim=latent_bank_max_dim,
|
||||||
|
@ -113,4 +113,9 @@ class GleanGenerator(nn.Module):
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_glean(opt_net, opt):
|
def register_glean(opt_net, opt):
|
||||||
return GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
kwargs = {}
|
||||||
|
exclusions = ['which_model_G', 'type']
|
||||||
|
for k, v in opt.items():
|
||||||
|
if k not in exclusions:
|
||||||
|
kwargs[k] = v
|
||||||
|
return GleanGenerator(**kwargs)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user