diff --git a/codes/models/archs/stylegan/__init__.py b/codes/models/archs/stylegan/__init__.py index 4ab78ddf..c38166f4 100644 --- a/codes/models/archs/stylegan/__init__.py +++ b/codes/models/archs/stylegan/__init__.py @@ -1,14 +1,14 @@ -from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss -from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss +import models.archs.stylegan.stylegan2 as stylegan2 +import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet def create_stylegan2_loss(opt_loss, env): type = opt_loss['type'] if type == 'stylegan2_divergence': - return StyleGan2DivergenceLoss(opt_loss, env) + return stylegan2.StyleGan2DivergenceLoss(opt_loss, env) elif type == 'stylegan2_pathlen': - return StyleGan2PathLengthLoss(opt_loss, env) + return stylegan2.StyleGan2PathLengthLoss(opt_loss, env) elif type == 'stylegan2_unet_divergence': - return StyleGan2UnetDivergenceLoss(opt_loss, env) + return stylegan2_unet.StyleGan2UnetDivergenceLoss(opt_loss, env) else: raise NotImplementedError \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index 95ac3ae1..dffecec2 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -6,6 +6,8 @@ import munch import torch import torchvision from munch import munchify +import models.archs.stylegan.stylegan2 as stylegan2 +import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch import models.archs.RRDBNet_arch as RRDBNet_arch @@ -22,8 +24,6 @@ from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator from models.archs.pyramid_arch import BasicResamplingFlowNet from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2 -from models.archs.stylegan.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor -from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDiscriminator from models.archs.teco_resgen import TecoGen logger = logging.getLogger('base') @@ -136,7 +136,7 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == 'stylegan2': is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] - netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], + netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], style_depth=opt_net['style_depth'], structure_input=is_structured, attn_layers=attn) else: @@ -199,11 +199,11 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf']) elif which_model == "stylegan2_discriminator": attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] - disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn) - netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) + disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn) + netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) elif which_model == "stylegan2_unet": - disc = StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc']) - netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) + disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc']) + netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD