More refactor changes

James Betker 2020-12-18 09:24:31 -07:00
parent 5640e4efe4
commit d875ca8342
35 changed files with 80 additions and 80 deletions
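Every hunk below makes the same mechanical change: imports under the old models.archs.* namespace move to models.*, and the srflow_orig package is referenced as srflow. The short Python sketch that follows is purely illustrative and is not part of the commit; it shows how a rename of this shape could be applied across a source tree, assuming the REWRITES list and the "codes" root directory (both hypothetical here) cover the cases. A few renames in this commit, such as trainer.eval.style, would still need to be made by hand.

# Illustrative only -- not part of this commit. Applies the module-path
# rewrites seen in the hunks below to every .py file under a source root.
import pathlib

# Old prefix -> new prefix. The specific srflow_orig rules run before the
# generic "drop the archs package" rule so the two renames do not collide.
REWRITES = [
    ("models.archs.srflow_orig", "models.srflow"),
    ("models.srflow_orig", "models.srflow"),
    ("models.archs.", "models."),
    ("from models.archs import", "from models import"),
]

def rewrite_tree(root: str) -> None:
    for path in pathlib.Path(root).rglob("*.py"):
        text = path.read_text()
        new_text = text
        for old, new in REWRITES:
            new_text = new_text.replace(old, new)
        if new_text != text:
            path.write_text(new_text)
            print(f"rewrote {path}")

if __name__ == "__main__":
    rewrite_tree("codes")  # assumed source root; adjust for the actual repo layout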

View File

@@ -15,7 +15,7 @@ import torch.nn.functional as F
 from tqdm import tqdm
 from data import create_dataset
-from models.archs.arch_util import PixelUnshuffle
+from models.arch_util import PixelUnshuffle
 from utils.util import opt_get

View File

@@ -9,7 +9,7 @@ from torchvision import transforms
 import torch.nn as nn
 from pathlib import Path
-import models.archs.stylegan.stylegan2_lucidrains as sg2
+import models.stylegan.stylegan2_lucidrains as sg2
 def convert_transparent_to_rgb(image):

View File

@@ -1,9 +1,9 @@
-import models.archs.SwitchedResidualGenerator_arch as srg
+import models.SwitchedResidualGenerator_arch as srg
 import torch
 import torch.nn as nn
 from switched_conv.switched_conv_util import save_attention_to_image
 from switched_conv.switched_conv import compute_attention_specificity
-from models.archs.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
+from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
 import functools
 import torch.nn.functional as F

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
-from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
+from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
 from utils.util import checkpoint, sequential_checkpoint

View File

@@ -1,14 +1,14 @@
 import torch
 from torch import nn
-from switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
+from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
+from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
     SiLU, UpconvBlock, ReferenceJoinBlock
-from switched_conv.switched_conv_util import save_attention_to_image_rgb
+from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
 import os
-from models.archs.spinenet_arch import SpineNet
+from models.spinenet_arch import SpineNet
 import torchvision
 from utils.util import checkpoint

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from torch import nn
 from data.byol_attachment import reconstructed_shared_regions
-from models.archs.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
+from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
     update_moving_average
 from utils.util import checkpoint

View File

@@ -1,10 +1,10 @@
 import torch
 import torch.nn as nn
-from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
+from models.RRDBNet_arch import RRDB, RRDBWithBypass
+from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
 import torch.nn.functional as F
-from models.archs.SwitchedResidualGenerator_arch import gather_2d
+from models.SwitchedResidualGenerator_arch import gather_2d
 from utils.util import checkpoint

View File

@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from torch.nn.init import kaiming_normal
 from torchvision.models.resnet import BasicBlock, Bottleneck
-from models.archs.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
+from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
 def constant_init(module, val, bias=0):

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import nn as nn
-from models.srflow_orig import thops
+from models.srflow import thops
 class _ActNorm(nn.Module):

View File

@@ -1,8 +1,8 @@
 import torch
 from torch import nn as nn
-from models.srflow_orig import thops
-from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
+from models.srflow import thops
+from models.srflow.flow import Conv2d, Conv2dZeros
 from utils.util import opt_get

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import nn as nn
-import models.archs.srflow_orig.Permutations
+import models.srflow.Permutations
 def getConditional(rrdbResults, position):
@@ -44,17 +44,17 @@ class FlowStep(nn.Module):
         self.acOpt = acOpt
         # 1. actnorm
-        self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
+        self.actnorm = models.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
         # 2. permute
         if flow_permutation == "invconv":
-            self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
+            self.invconv = models.srflow.Permutations.InvertibleConv1x1(
                 in_channels, LU_decomposed=LU_decomposed)
         # 3. coupling
         if flow_coupling == "CondAffineSeparatedAndCond":
-            self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
+            self.affine = models.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
                                                                                                opt=opt)
         elif flow_coupling == "noCoupling":
             pass
         else:

View File

@@ -2,12 +2,12 @@ import numpy as np
 import torch
 from torch import nn as nn
-import models.archs.srflow_orig.Split
-from models.archs.srflow_orig import flow
-from models.srflow_orig import thops
-from models.archs.srflow_orig.Split import Split2d
-from models.archs.srflow_orig.glow_arch import f_conv2d_bias
-from models.archs.srflow_orig.FlowStep import FlowStep
+import models.srflow.Split
+from models.srflow import flow
+from models.srflow import thops
+from models.srflow.Split import Split2d
+from models.srflow.glow_arch import f_conv2d_bias
+from models.srflow.FlowStep import FlowStep
 from utils.util import opt_get, checkpoint
@@ -146,8 +146,8 @@ class FlowUpsamplerNet(nn.Module):
             t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
             if t == 'Split2d':
-                split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
+                split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                                     cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
                 self.layers.append(split)
                 self.output_shapes.append([-1, split.num_channels_pass, H, W])
                 self.C = split.num_channels_pass

View File

@@ -3,7 +3,7 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
-from models.srflow_orig import thops
+from models.srflow import thops
 class InvertibleConv1x1(nn.Module):

View File

@@ -2,8 +2,8 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import models.archs.srflow_orig.module_util as mutil
-from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
+import models.srflow.module_util as mutil
+from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
 from utils.util import opt_get

View File

@@ -4,10 +4,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
-from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
-import models.srflow_orig.thops as thops
-import models.archs.srflow_orig.flow as flow
+from models.srflow.RRDBNet_arch import RRDBNet
+from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
+import models.srflow.thops as thops
+import models.srflow.flow as flow
 from utils.util import opt_get

View File

@@ -1,8 +1,8 @@
 import torch
 from torch import nn as nn
-from models.srflow_orig import thops
-from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
+from models.srflow import thops
+from models.srflow.flow import Conv2dZeros, GaussianDiag
 from utils.util import opt_get

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from models.archs.srflow_orig.FlowActNorms import ActNorm2d
+from models.srflow.FlowActNorms import ActNorm2d
 from . import thops

View File

@@ -8,9 +8,9 @@ import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
-from switched_conv.switched_conv import BareConvSwitch, AttentionNorm
+from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
+from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
+from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm
 from utils.util import checkpoint

View File

@@ -1,4 +1,4 @@
-import models.archs.stylegan.stylegan2_lucidrains as stylegan2
+import models.stylegan.stylegan2_lucidrains as stylegan2
 def create_stylegan2_loss(opt_loss, env):

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torchvision
 from utils.util import sequential_checkpoint
-from models.archs.arch_util import ConvGnSilu, make_layer
+from models.arch_util import ConvGnSilu, make_layer
 class TecoResblock(nn.Module):

View File

@@ -1,6 +1,6 @@
 import torch
-from models.archs.spinenet_arch import SpineNet
+from models.spinenet_arch import SpineNet
 if __name__ == '__main__':
     pretrained_path = '../../experiments/train_sbyol_512unsupervised_restart/models/48000_generator.pth'

View File

@@ -13,7 +13,7 @@ import numpy as np
 import utils
 from data.image_folder_dataset import ImageFolderDataset
-from models.archs.spinenet_arch import SpineNet
+from models.spinenet_arch import SpineNet
 # Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved

View File

@@ -13,7 +13,7 @@ import torch
 import numpy as np
 from torchvision import utils
-from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator
+from models.stylegan.stylegan2_rosinality import Generator, Discriminator
 # Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.

View File

@@ -1,7 +1,7 @@
 import torch
 from torch.cuda.amp import autocast
-from models.archs.flownet2.networks import Resample2d
-from models.archs.flownet2 import flow2img
+from models.flownet2.networks import Resample2d
+from models.flownet2 import flow2img
 from trainer.injectors import Injector

View File

@@ -1,8 +1,8 @@
 from torch.cuda.amp import autocast
-from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
+from models.stylegan.stylegan2_lucidrains import gradient_penalty
 from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
-from models.archs.flownet2.networks import Resample2d
+from models.flownet2.networks import Resample2d
 from trainer.injectors import Injector
 import torch
 import torch.nn.functional as F

View File

@@ -1,6 +1,6 @@
 from trainer.eval.flow_gaussian_nll import FlowGaussianNll
 from trainer.eval.sr_style import SrStyleTransferEvaluator
-from trainer.eval import StyleTransferEvaluator
+from trainer.eval.style import StyleTransferEvaluator
 def create_evaluator(model, opt_eval, env):

View File

@@ -6,7 +6,7 @@ import trainer.eval.evaluator as evaluator
 # Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
 from data.image_folder_dataset import ImageFolderDataset
-from models.archs.srflow_orig.flow import GaussianDiag
+from models.srflow.flow import GaussianDiag
 class FlowGaussianNll(evaluator.Evaluator):

View File

@@ -19,7 +19,7 @@ def create_injector(opt_inject, env):
         from trainer.custom_training_components import create_stereoscopic_injector
         return create_stereoscopic_injector(opt_inject, env)
     elif 'igpt' in type:
-        from models.archs.transformers.igpt import gpt2
+        from models.transformers.igpt import gpt2
         return gpt2.create_injector(opt_inject, env)
     elif type == 'generator':
         return ImageGeneratorInjector(opt_inject, env)
@@ -372,7 +372,7 @@ class MultiFrameCombiner(Injector):
         self.in_hq_key = opt['in_hq']
         self.out_lq_key = opt['out']
         self.out_hq_key = opt['out_hq']
-        from models.archs.flownet2.networks import Resample2d
+        from models.flownet2.networks import Resample2d
         self.resampler = Resample2d()
     def combine(self, state):

View File

@@ -14,7 +14,7 @@ def create_loss(opt_loss, env):
         from trainer.custom_training_components import create_teco_loss
         return create_teco_loss(opt_loss, env)
     elif 'stylegan2_' in type:
-        from models.archs.stylegan import create_stylegan2_loss
+        from models.stylegan import create_stylegan2_loss
         return create_stylegan2_loss(opt_loss, env)
     elif type == 'crossentropy':
         return CrossEntropy(opt_loss, env)
@@ -311,7 +311,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         if self.gradient_penalty:
             # Apply gradient penalty. TODO: migrate this elsewhere.
-            from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
+            from models.stylegan.stylegan2_lucidrains import gradient_penalty
             assert len(real) == 1  # Grad penalty doesn't currently support multi-input discriminators.
             gp = gradient_penalty(real[0], d_real)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))

View File

@@ -6,16 +6,16 @@ import munch
 import torch
 import torchvision
 from munch import munchify
-import models.archs.stylegan.stylegan2_lucidrains as stylegan2
-import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
-import models.archs.RRDBNet_arch as RRDBNet_arch
-import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
-import models.archs.discriminator_vgg_arch as SRGAN_arch
-import models.archs.feature_arch as feature_arch
-from models.archs import srg2_classic
-from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
-from models.archs.tecogan.teco_resgen import TecoGen
+import models.stylegan.stylegan2_lucidrains as stylegan2
+import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
+import models.RRDBNet_arch as RRDBNet_arch
+import models.SwitchedResidualGenerator_arch as SwitchedGen_arch
+import models.discriminator_vgg_arch as SRGAN_arch
+import models.feature_arch as feature_arch
+from models import srg2_classic
+from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
+from models.tecogan.teco_resgen import TecoGen
 from utils.util import opt_get
 logger = logging.getLogger('base')
@@ -30,7 +30,7 @@ def define_G(opt, opt_net, scale=None):
         if which_model == 'RRDBNetBypass':
             block = RRDBNet_arch.RRDBWithBypass
         elif which_model == 'RRDBNetLambda':
-            from models.archs.lambda_rrdb import LambdaRRDB
+            from models.lambda_rrdb import LambdaRRDB
             block = LambdaRRDB
         else:
             block = RRDBNet_arch.RRDB
@@ -62,7 +62,7 @@ def define_G(opt, opt_net, scale=None):
                                           heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                           upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
     elif which_model == "flownet2":
-        from models.archs.flownet2 import FlowNet2
+        from models.flownet2 import FlowNet2
         ld = 'load_path' in opt_net.keys()
         args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
         netG = FlowNet2(args)
@@ -85,12 +85,12 @@ def define_G(opt, opt_net, scale=None):
         netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
                                                       style_depth=opt_net['style_depth'], structure_input=is_structured,
                                                       attn_layers=attn)
-    elif which_model == 'srflow_orig':
-        from models.archs.srflow_orig import SRFlowNet_arch
+    elif which_model == 'srflow':
+        from models.srflow import SRFlowNet_arch
         netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
                                         K=opt_net['K'], opt=opt)
     elif which_model == 'rrdb_latent_wrapper':
-        from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
+        from models.srflow.RRDBNet_arch import RRDBLatentWrapper
         netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                  nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
                                  blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
@@ -100,29 +100,29 @@ def define_G(opt, opt_net, scale=None):
                        mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
                        headless=True, output_mode=output_mode)
     elif which_model == 'rrdb_srflow':
-        from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
+        from models.srflow.RRDBNet_arch import RRDBNet
         netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                        nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
                        initial_conv_stride=opt_net['initial_stride'])
     elif which_model == 'igpt2':
-        from models.archs.transformers.igpt.gpt2 import iGPT2
+        from models.transformers.igpt.gpt2 import iGPT2
         netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
     elif which_model == 'byol':
-        from models.archs.byol.byol_model_wrapper import BYOL
+        from models.byol.byol_model_wrapper import BYOL
         subnet = define_G(opt, opt_net['subnet'])
         netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                     structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
     elif which_model == 'structural_byol':
-        from models.archs.byol.byol_structural import StructuralBYOL
+        from models.byol.byol_structural import StructuralBYOL
         subnet = define_G(opt, opt_net['subnet'])
         netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                               pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
                               freeze_until=opt_get(opt_net, ['freeze_until'], 0))
     elif which_model == 'spinenet':
-        from models.archs.spinenet_arch import SpineNet
+        from models.spinenet_arch import SpineNet
         netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
     elif which_model == 'spinenet_with_logits':
-        from models.archs.spinenet_arch import SpinenetWithLogits
+        from models.spinenet_arch import SpinenetWithLogits
         netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
                                   in_channels=3, use_input_norm=opt_net['use_input_norm'])
     else:

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import nn
-import models.archs.SwitchedResidualGenerator_arch as srg
-import models.archs.discriminator_vgg_arch as disc
+import models.SwitchedResidualGenerator_arch as srg
+import models.discriminator_vgg_arch as disc
 import functools
 blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]