forked from mrq/DL-Art-School
Allow attention to be specified for stylegan2
This commit is contained in:
parent
cdc5ac30e9
commit
423ee7cb90
|
@ -537,7 +537,7 @@ class Generator(nn.Module):
|
||||||
|
|
||||||
for style, block, attn in zip(styles, self.blocks, self.attns):
|
for style, block, attn in zip(styles, self.blocks, self.attns):
|
||||||
if exists(attn):
|
if exists(attn):
|
||||||
x = attn(x)
|
x = checkpoint(attn, x)
|
||||||
x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input)
|
x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input)
|
||||||
|
|
||||||
return rgb
|
return rgb
|
||||||
|
|
|
@ -134,8 +134,10 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
|
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
|
||||||
elif which_model == 'stylegan2':
|
elif which_model == 'stylegan2':
|
||||||
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
||||||
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||||
style_depth=opt_net['style_depth'], structure_input=is_structured)
|
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||||
|
attn_layers=attn)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
@ -195,7 +197,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
elif which_model == "pyramid_disc":
|
elif which_model == "pyramid_disc":
|
||||||
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
||||||
elif which_model == "stylegan2_discriminator":
|
elif which_model == "stylegan2_discriminator":
|
||||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
|
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
||||||
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user