diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan2.py index be6004f2..a1c5a1cb 100644 --- a/codes/models/archs/stylegan2.py +++ b/codes/models/archs/stylegan2.py @@ -537,7 +537,7 @@ class Generator(nn.Module): for style, block, attn in zip(styles, self.blocks, self.attns): if exists(attn): - x = attn(x) + x = checkpoint(attn, x) x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input) return rgb diff --git a/codes/models/networks.py b/codes/models/networks.py index ed015203..0d5296b0 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -134,8 +134,10 @@ def define_G(opt, net_key='network_G', scale=None): netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf']) elif which_model == 'stylegan2': is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False + attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], - style_depth=opt_net['style_depth'], structure_input=is_structured) + style_depth=opt_net['style_depth'], structure_input=is_structured, + attn_layers=attn) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG @@ -195,7 +197,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False): elif which_model == "pyramid_disc": netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf']) elif which_model == "stylegan2_discriminator": - disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc']) + attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] + disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn) netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))