Fix circular imports
This commit is contained in:
parent
99f0cfaab5
commit
e587d549f7
|
@ -1,14 +1,14 @@
|
||||||
from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss
|
import models.archs.stylegan.stylegan2 as stylegan2
|
||||||
from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss
|
import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
|
||||||
|
|
||||||
|
|
||||||
def create_stylegan2_loss(opt_loss, env):
|
def create_stylegan2_loss(opt_loss, env):
|
||||||
type = opt_loss['type']
|
type = opt_loss['type']
|
||||||
if type == 'stylegan2_divergence':
|
if type == 'stylegan2_divergence':
|
||||||
return StyleGan2DivergenceLoss(opt_loss, env)
|
return stylegan2.StyleGan2DivergenceLoss(opt_loss, env)
|
||||||
elif type == 'stylegan2_pathlen':
|
elif type == 'stylegan2_pathlen':
|
||||||
return StyleGan2PathLengthLoss(opt_loss, env)
|
return stylegan2.StyleGan2PathLengthLoss(opt_loss, env)
|
||||||
elif type == 'stylegan2_unet_divergence':
|
elif type == 'stylegan2_unet_divergence':
|
||||||
return StyleGan2UnetDivergenceLoss(opt_loss, env)
|
return stylegan2_unet.StyleGan2UnetDivergenceLoss(opt_loss, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
|
@ -6,6 +6,8 @@ import munch
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from munch import munchify
|
from munch import munchify
|
||||||
|
import models.archs.stylegan.stylegan2 as stylegan2
|
||||||
|
import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
|
||||||
|
|
||||||
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||||
|
@ -22,8 +24,6 @@ from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||||
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
||||||
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
|
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
|
||||||
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
||||||
from models.archs.stylegan.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor
|
|
||||||
from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDiscriminator
|
|
||||||
from models.archs.teco_resgen import TecoGen
|
from models.archs.teco_resgen import TecoGen
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
@ -136,7 +136,7 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
elif which_model == 'stylegan2':
|
elif which_model == 'stylegan2':
|
||||||
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
||||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||||
attn_layers=attn)
|
attn_layers=attn)
|
||||||
else:
|
else:
|
||||||
|
@ -199,11 +199,11 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
||||||
elif which_model == "stylegan2_discriminator":
|
elif which_model == "stylegan2_discriminator":
|
||||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
|
||||||
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||||
elif which_model == "stylegan2_unet":
|
elif which_model == "stylegan2_unet":
|
||||||
disc = StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
|
disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
|
||||||
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
Loading…
Reference in New Issue
Block a user