forked from mrq/DL-Art-School
More refactor changes
This commit is contained in:
parent
5640e4efe4
commit
d875ca8342
|
@ -15,7 +15,7 @@ import torch.nn.functional as F
|
|||
from tqdm import tqdm
|
||||
|
||||
from data import create_dataset
|
||||
from models.archs.arch_util import PixelUnshuffle
|
||||
from models.arch_util import PixelUnshuffle
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from torchvision import transforms
|
|||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
|
||||
import models.archs.stylegan.stylegan2_lucidrains as sg2
|
||||
import models.stylegan.stylegan2_lucidrains as sg2
|
||||
|
||||
|
||||
def convert_transparent_to_rgb(image):
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
import models.SwitchedResidualGenerator_arch as srg
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from switched_conv.switched_conv_util import save_attention_to_image
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from models.archs.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
|
||||
from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
|
||||
import functools
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from utils.util import checkpoint, sequential_checkpoint
|
||||
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
||||
from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
||||
import torch.nn.functional as F
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
|
||||
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
|
||||
SiLU, UpconvBlock, ReferenceJoinBlock
|
||||
from switched_conv.switched_conv_util import save_attention_to_image_rgb
|
||||
from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
|
||||
import os
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
from models.spinenet_arch import SpineNet
|
||||
import torchvision
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
from torch import nn
|
||||
|
||||
from data.byol_attachment import reconstructed_shared_regions
|
||||
from models.archs.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
|
||||
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
|
||||
update_moving_average
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
||||
from models.RRDBNet_arch import RRDB, RRDBWithBypass
|
||||
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
||||
import torch.nn.functional as F
|
||||
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
||||
from models.SwitchedResidualGenerator_arch import gather_2d
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
from torch.nn.init import kaiming_normal
|
||||
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
from models.archs.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
|
||||
from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
|
||||
|
||||
|
||||
def constant_init(module, val, bias=0):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.srflow_orig import thops
|
||||
from models.srflow import thops
|
||||
|
||||
|
||||
class _ActNorm(nn.Module):
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.srflow_orig import thops
|
||||
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
||||
from models.srflow import thops
|
||||
from models.srflow.flow import Conv2d, Conv2dZeros
|
||||
from utils.util import opt_get
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
import models.archs.srflow_orig.Permutations
|
||||
import models.srflow.Permutations
|
||||
|
||||
|
||||
def getConditional(rrdbResults, position):
|
||||
|
@ -44,17 +44,17 @@ class FlowStep(nn.Module):
|
|||
self.acOpt = acOpt
|
||||
|
||||
# 1. actnorm
|
||||
self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
||||
self.actnorm = models.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
||||
|
||||
# 2. permute
|
||||
if flow_permutation == "invconv":
|
||||
self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
|
||||
self.invconv = models.srflow.Permutations.InvertibleConv1x1(
|
||||
in_channels, LU_decomposed=LU_decomposed)
|
||||
|
||||
# 3. coupling
|
||||
if flow_coupling == "CondAffineSeparatedAndCond":
|
||||
self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
||||
opt=opt)
|
||||
self.affine = models.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
||||
opt=opt)
|
||||
elif flow_coupling == "noCoupling":
|
||||
pass
|
||||
else:
|
|
@ -2,12 +2,12 @@ import numpy as np
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
import models.archs.srflow_orig.Split
|
||||
from models.archs.srflow_orig import flow
|
||||
from models.srflow_orig import thops
|
||||
from models.archs.srflow_orig.Split import Split2d
|
||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||
import models.srflow.Split
|
||||
from models.srflow import flow
|
||||
from models.srflow import thops
|
||||
from models.srflow.Split import Split2d
|
||||
from models.srflow.glow_arch import f_conv2d_bias
|
||||
from models.srflow.FlowStep import FlowStep
|
||||
from utils.util import opt_get, checkpoint
|
||||
|
||||
|
||||
|
@ -146,8 +146,8 @@ class FlowUpsamplerNet(nn.Module):
|
|||
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
|
||||
|
||||
if t == 'Split2d':
|
||||
split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
||||
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
||||
split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
||||
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
||||
self.layers.append(split)
|
||||
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
||||
self.C = split.num_channels_pass
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.srflow_orig import thops
|
||||
from models.srflow import thops
|
||||
|
||||
|
||||
class InvertibleConv1x1(nn.Module):
|
|
@ -2,8 +2,8 @@ import functools
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import models.archs.srflow_orig.module_util as mutil
|
||||
from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
import models.srflow.module_util as mutil
|
||||
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from utils.util import opt_get
|
||||
|
||||
|
|
@ -4,10 +4,10 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.srflow_orig.thops as thops
|
||||
import models.archs.srflow_orig.flow as flow
|
||||
from models.srflow.RRDBNet_arch import RRDBNet
|
||||
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.srflow.thops as thops
|
||||
import models.srflow.flow as flow
|
||||
from utils.util import opt_get
|
||||
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.srflow_orig import thops
|
||||
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
||||
from models.srflow import thops
|
||||
from models.srflow.flow import Conv2dZeros, GaussianDiag
|
||||
from utils.util import opt_get
|
||||
|
||||
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from models.archs.srflow_orig.FlowActNorms import ActNorm2d
|
||||
from models.srflow.FlowActNorms import ActNorm2d
|
||||
from . import thops
|
||||
|
||||
|
|
@ -8,9 +8,9 @@ import torch.nn.functional as F
|
|||
import functools
|
||||
from collections import OrderedDict
|
||||
|
||||
from models.archs.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
|
||||
from switched_conv.switched_conv import BareConvSwitch, AttentionNorm
|
||||
from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
|
||||
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
|
||||
from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import models.archs.stylegan.stylegan2_lucidrains as stylegan2
|
||||
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||
|
||||
|
||||
def create_stylegan2_loss(opt_loss, env):
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torchvision
|
||||
|
||||
from utils.util import sequential_checkpoint
|
||||
from models.archs.arch_util import ConvGnSilu, make_layer
|
||||
from models.arch_util import ConvGnSilu, make_layer
|
||||
|
||||
|
||||
class TecoResblock(nn.Module):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../../experiments/train_sbyol_512unsupervised_restart/models/48000_generator.pth'
|
||||
|
|
|
@ -13,7 +13,7 @@ import numpy as np
|
|||
|
||||
import utils
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
|
||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch
|
|||
import numpy as np
|
||||
from torchvision import utils
|
||||
|
||||
from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator
|
||||
from models.stylegan.stylegan2_rosinality import Generator, Discriminator
|
||||
|
||||
|
||||
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
from models.archs.flownet2.networks import Resample2d
|
||||
from models.archs.flownet2 import flow2img
|
||||
from models.flownet2.networks import Resample2d
|
||||
from models.flownet2 import flow2img
|
||||
from trainer.injectors import Injector
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from torch.cuda.amp import autocast
|
||||
|
||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||
from models.archs.flownet2.networks import Resample2d
|
||||
from models.flownet2.networks import Resample2d
|
||||
from trainer.injectors import Injector
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
||||
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
||||
from trainer.eval import StyleTransferEvaluator
|
||||
from trainer.eval.style import StyleTransferEvaluator
|
||||
|
||||
|
||||
def create_evaluator(model, opt_eval, env):
|
||||
|
|
|
@ -6,7 +6,7 @@ import trainer.eval.evaluator as evaluator
|
|||
|
||||
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.archs.srflow_orig.flow import GaussianDiag
|
||||
from models.srflow.flow import GaussianDiag
|
||||
|
||||
|
||||
class FlowGaussianNll(evaluator.Evaluator):
|
||||
|
|
|
@ -19,7 +19,7 @@ def create_injector(opt_inject, env):
|
|||
from trainer.custom_training_components import create_stereoscopic_injector
|
||||
return create_stereoscopic_injector(opt_inject, env)
|
||||
elif 'igpt' in type:
|
||||
from models.archs.transformers.igpt import gpt2
|
||||
from models.transformers.igpt import gpt2
|
||||
return gpt2.create_injector(opt_inject, env)
|
||||
elif type == 'generator':
|
||||
return ImageGeneratorInjector(opt_inject, env)
|
||||
|
@ -372,7 +372,7 @@ class MultiFrameCombiner(Injector):
|
|||
self.in_hq_key = opt['in_hq']
|
||||
self.out_lq_key = opt['out']
|
||||
self.out_hq_key = opt['out_hq']
|
||||
from models.archs.flownet2.networks import Resample2d
|
||||
from models.flownet2.networks import Resample2d
|
||||
self.resampler = Resample2d()
|
||||
|
||||
def combine(self, state):
|
||||
|
|
|
@ -14,7 +14,7 @@ def create_loss(opt_loss, env):
|
|||
from trainer.custom_training_components import create_teco_loss
|
||||
return create_teco_loss(opt_loss, env)
|
||||
elif 'stylegan2_' in type:
|
||||
from models.archs.stylegan import create_stylegan2_loss
|
||||
from models.stylegan import create_stylegan2_loss
|
||||
return create_stylegan2_loss(opt_loss, env)
|
||||
elif type == 'crossentropy':
|
||||
return CrossEntropy(opt_loss, env)
|
||||
|
@ -311,7 +311,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
|
||||
if self.gradient_penalty:
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
||||
gp = gradient_penalty(real[0], d_real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
|
|
|
@ -6,16 +6,16 @@ import munch
|
|||
import torch
|
||||
import torchvision
|
||||
from munch import munchify
|
||||
import models.archs.stylegan.stylegan2_lucidrains as stylegan2
|
||||
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||
|
||||
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
||||
import models.archs.feature_arch as feature_arch
|
||||
from models.archs import srg2_classic
|
||||
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||
from models.archs.tecogan.teco_resgen import TecoGen
|
||||
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||
import models.RRDBNet_arch as RRDBNet_arch
|
||||
import models.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||
import models.discriminator_vgg_arch as SRGAN_arch
|
||||
import models.feature_arch as feature_arch
|
||||
from models import srg2_classic
|
||||
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||
from models.tecogan.teco_resgen import TecoGen
|
||||
from utils.util import opt_get
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
@ -30,7 +30,7 @@ def define_G(opt, opt_net, scale=None):
|
|||
if which_model == 'RRDBNetBypass':
|
||||
block = RRDBNet_arch.RRDBWithBypass
|
||||
elif which_model == 'RRDBNetLambda':
|
||||
from models.archs.lambda_rrdb import LambdaRRDB
|
||||
from models.lambda_rrdb import LambdaRRDB
|
||||
block = LambdaRRDB
|
||||
else:
|
||||
block = RRDBNet_arch.RRDB
|
||||
|
@ -62,7 +62,7 @@ def define_G(opt, opt_net, scale=None):
|
|||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "flownet2":
|
||||
from models.archs.flownet2 import FlowNet2
|
||||
from models.flownet2 import FlowNet2
|
||||
ld = 'load_path' in opt_net.keys()
|
||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
||||
netG = FlowNet2(args)
|
||||
|
@ -85,12 +85,12 @@ def define_G(opt, opt_net, scale=None):
|
|||
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||
attn_layers=attn)
|
||||
elif which_model == 'srflow_orig':
|
||||
from models.archs.srflow_orig import SRFlowNet_arch
|
||||
elif which_model == 'srflow':
|
||||
from models.srflow import SRFlowNet_arch
|
||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
K=opt_net['K'], opt=opt)
|
||||
elif which_model == 'rrdb_latent_wrapper':
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
|
||||
from models.srflow.RRDBNet_arch import RRDBLatentWrapper
|
||||
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||
|
@ -100,29 +100,29 @@ def define_G(opt, opt_net, scale=None):
|
|||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
||||
headless=True, output_mode=output_mode)
|
||||
elif which_model == 'rrdb_srflow':
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||
from models.srflow.RRDBNet_arch import RRDBNet
|
||||
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
initial_conv_stride=opt_net['initial_stride'])
|
||||
elif which_model == 'igpt2':
|
||||
from models.archs.transformers.igpt.gpt2 import iGPT2
|
||||
from models.transformers.igpt.gpt2 import iGPT2
|
||||
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
|
||||
elif which_model == 'byol':
|
||||
from models.archs.byol.byol_model_wrapper import BYOL
|
||||
from models.byol.byol_model_wrapper import BYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
||||
elif which_model == 'structural_byol':
|
||||
from models.archs.byol.byol_structural import StructuralBYOL
|
||||
from models.byol.byol_structural import StructuralBYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
|
||||
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
|
||||
elif which_model == 'spinenet':
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
from models.spinenet_arch import SpineNet
|
||||
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
elif which_model == 'spinenet_with_logits':
|
||||
from models.archs.spinenet_arch import SpinenetWithLogits
|
||||
from models.spinenet_arch import SpinenetWithLogits
|
||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
else:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
import models.archs.discriminator_vgg_arch as disc
|
||||
import models.SwitchedResidualGenerator_arch as srg
|
||||
import models.discriminator_vgg_arch as disc
|
||||
import functools
|
||||
|
||||
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
|
||||
|
|
Loading…
Reference in New Issue
Block a user