forked from mrq/DL-Art-School
More refactor changes
This commit is contained in:
parent
5640e4efe4
commit
d875ca8342
|
@ -15,7 +15,7 @@ import torch.nn.functional as F
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from data import create_dataset
|
from data import create_dataset
|
||||||
from models.archs.arch_util import PixelUnshuffle
|
from models.arch_util import PixelUnshuffle
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ from torchvision import transforms
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import models.archs.stylegan.stylegan2_lucidrains as sg2
|
import models.stylegan.stylegan2_lucidrains as sg2
|
||||||
|
|
||||||
|
|
||||||
def convert_transparent_to_rgb(image):
|
def convert_transparent_to_rgb(image):
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
import models.SwitchedResidualGenerator_arch as srg
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from switched_conv.switched_conv_util import save_attention_to_image
|
from switched_conv.switched_conv_util import save_attention_to_image
|
||||||
from switched_conv.switched_conv import compute_attention_specificity
|
from switched_conv.switched_conv import compute_attention_specificity
|
||||||
from models.archs.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
|
from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
|
||||||
import functools
|
import functools
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||||
from utils.util import checkpoint, sequential_checkpoint
|
from utils.util import checkpoint, sequential_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import functools
|
import functools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
|
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
|
||||||
SiLU, UpconvBlock, ReferenceJoinBlock
|
SiLU, UpconvBlock, ReferenceJoinBlock
|
||||||
from switched_conv.switched_conv_util import save_attention_to_image_rgb
|
from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
|
||||||
import os
|
import os
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.spinenet_arch import SpineNet
|
||||||
import torchvision
|
import torchvision
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from data.byol_attachment import reconstructed_shared_regions
|
from data.byol_attachment import reconstructed_shared_regions
|
||||||
from models.archs.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
|
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
|
||||||
update_moving_average
|
update_moving_average
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass
|
from models.RRDBNet_arch import RRDB, RRDBWithBypass
|
||||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
from models.SwitchedResidualGenerator_arch import gather_2d
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
||||||
from torch.nn.init import kaiming_normal
|
from torch.nn.init import kaiming_normal
|
||||||
|
|
||||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||||
from models.archs.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
|
from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
|
||||||
|
|
||||||
|
|
||||||
def constant_init(module, val, bias=0):
|
def constant_init(module, val, bias=0):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.srflow_orig import thops
|
from models.srflow import thops
|
||||||
|
|
||||||
|
|
||||||
class _ActNorm(nn.Module):
|
class _ActNorm(nn.Module):
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.srflow_orig import thops
|
from models.srflow import thops
|
||||||
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
from models.srflow.flow import Conv2d, Conv2dZeros
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
import models.archs.srflow_orig.Permutations
|
import models.srflow.Permutations
|
||||||
|
|
||||||
|
|
||||||
def getConditional(rrdbResults, position):
|
def getConditional(rrdbResults, position):
|
||||||
|
@ -44,17 +44,17 @@ class FlowStep(nn.Module):
|
||||||
self.acOpt = acOpt
|
self.acOpt = acOpt
|
||||||
|
|
||||||
# 1. actnorm
|
# 1. actnorm
|
||||||
self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
self.actnorm = models.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
||||||
|
|
||||||
# 2. permute
|
# 2. permute
|
||||||
if flow_permutation == "invconv":
|
if flow_permutation == "invconv":
|
||||||
self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
|
self.invconv = models.srflow.Permutations.InvertibleConv1x1(
|
||||||
in_channels, LU_decomposed=LU_decomposed)
|
in_channels, LU_decomposed=LU_decomposed)
|
||||||
|
|
||||||
# 3. coupling
|
# 3. coupling
|
||||||
if flow_coupling == "CondAffineSeparatedAndCond":
|
if flow_coupling == "CondAffineSeparatedAndCond":
|
||||||
self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
self.affine = models.srflow.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
||||||
opt=opt)
|
opt=opt)
|
||||||
elif flow_coupling == "noCoupling":
|
elif flow_coupling == "noCoupling":
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
|
@ -2,12 +2,12 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
import models.archs.srflow_orig.Split
|
import models.srflow.Split
|
||||||
from models.archs.srflow_orig import flow
|
from models.srflow import flow
|
||||||
from models.srflow_orig import thops
|
from models.srflow import thops
|
||||||
from models.archs.srflow_orig.Split import Split2d
|
from models.srflow.Split import Split2d
|
||||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
from models.srflow.glow_arch import f_conv2d_bias
|
||||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
from models.srflow.FlowStep import FlowStep
|
||||||
from utils.util import opt_get, checkpoint
|
from utils.util import opt_get, checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,8 +146,8 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
|
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
|
||||||
|
|
||||||
if t == 'Split2d':
|
if t == 'Split2d':
|
||||||
split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
||||||
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
||||||
self.layers.append(split)
|
self.layers.append(split)
|
||||||
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
||||||
self.C = split.num_channels_pass
|
self.C = split.num_channels_pass
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from models.srflow_orig import thops
|
from models.srflow import thops
|
||||||
|
|
||||||
|
|
||||||
class InvertibleConv1x1(nn.Module):
|
class InvertibleConv1x1(nn.Module):
|
|
@ -2,8 +2,8 @@ import functools
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import models.archs.srflow_orig.module_util as mutil
|
import models.srflow.module_util as mutil
|
||||||
from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,10 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
from models.srflow.RRDBNet_arch import RRDBNet
|
||||||
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
|
||||||
import models.srflow_orig.thops as thops
|
import models.srflow.thops as thops
|
||||||
import models.archs.srflow_orig.flow as flow
|
import models.srflow.flow as flow
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.srflow_orig import thops
|
from models.srflow import thops
|
||||||
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
from models.srflow.flow import Conv2dZeros, GaussianDiag
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from models.archs.srflow_orig.FlowActNorms import ActNorm2d
|
from models.srflow.FlowActNorms import ActNorm2d
|
||||||
from . import thops
|
from . import thops
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,9 @@ import torch.nn.functional as F
|
||||||
import functools
|
import functools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from models.archs.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
|
from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
|
||||||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
|
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
|
||||||
from switched_conv.switched_conv import BareConvSwitch, AttentionNorm
|
from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import models.archs.stylegan.stylegan2_lucidrains as stylegan2
|
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||||
|
|
||||||
|
|
||||||
def create_stylegan2_loss(opt_loss, env):
|
def create_stylegan2_loss(opt_loss, env):
|
||||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
from utils.util import sequential_checkpoint
|
from utils.util import sequential_checkpoint
|
||||||
from models.archs.arch_util import ConvGnSilu, make_layer
|
from models.arch_util import ConvGnSilu, make_layer
|
||||||
|
|
||||||
|
|
||||||
class TecoResblock(nn.Module):
|
class TecoResblock(nn.Module):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.spinenet_arch import SpineNet
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pretrained_path = '../../experiments/train_sbyol_512unsupervised_restart/models/48000_generator.pth'
|
pretrained_path = '../../experiments/train_sbyol_512unsupervised_restart/models/48000_generator.pth'
|
||||||
|
|
|
@ -13,7 +13,7 @@ import numpy as np
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.spinenet_arch import SpineNet
|
||||||
|
|
||||||
|
|
||||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||||
|
|
|
@ -13,7 +13,7 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torchvision import utils
|
from torchvision import utils
|
||||||
|
|
||||||
from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator
|
from models.stylegan.stylegan2_rosinality import Generator, Discriminator
|
||||||
|
|
||||||
|
|
||||||
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
|
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
from models.archs.flownet2.networks import Resample2d
|
from models.flownet2.networks import Resample2d
|
||||||
from models.archs.flownet2 import flow2img
|
from models.flownet2 import flow2img
|
||||||
from trainer.injectors import Injector
|
from trainer.injectors import Injector
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||||
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
from trainer.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||||
from models.archs.flownet2.networks import Resample2d
|
from models.flownet2.networks import Resample2d
|
||||||
from trainer.injectors import Injector
|
from trainer.injectors import Injector
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
from trainer.eval.flow_gaussian_nll import FlowGaussianNll
|
||||||
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
from trainer.eval.sr_style import SrStyleTransferEvaluator
|
||||||
from trainer.eval import StyleTransferEvaluator
|
from trainer.eval.style import StyleTransferEvaluator
|
||||||
|
|
||||||
|
|
||||||
def create_evaluator(model, opt_eval, env):
|
def create_evaluator(model, opt_eval, env):
|
||||||
|
|
|
@ -6,7 +6,7 @@ import trainer.eval.evaluator as evaluator
|
||||||
|
|
||||||
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
|
# Evaluate how close to true Gaussian a flow network predicts in a "normal" pass given a LQ/HQ image pair.
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
from models.archs.srflow_orig.flow import GaussianDiag
|
from models.srflow.flow import GaussianDiag
|
||||||
|
|
||||||
|
|
||||||
class FlowGaussianNll(evaluator.Evaluator):
|
class FlowGaussianNll(evaluator.Evaluator):
|
||||||
|
|
|
@ -19,7 +19,7 @@ def create_injector(opt_inject, env):
|
||||||
from trainer.custom_training_components import create_stereoscopic_injector
|
from trainer.custom_training_components import create_stereoscopic_injector
|
||||||
return create_stereoscopic_injector(opt_inject, env)
|
return create_stereoscopic_injector(opt_inject, env)
|
||||||
elif 'igpt' in type:
|
elif 'igpt' in type:
|
||||||
from models.archs.transformers.igpt import gpt2
|
from models.transformers.igpt import gpt2
|
||||||
return gpt2.create_injector(opt_inject, env)
|
return gpt2.create_injector(opt_inject, env)
|
||||||
elif type == 'generator':
|
elif type == 'generator':
|
||||||
return ImageGeneratorInjector(opt_inject, env)
|
return ImageGeneratorInjector(opt_inject, env)
|
||||||
|
@ -372,7 +372,7 @@ class MultiFrameCombiner(Injector):
|
||||||
self.in_hq_key = opt['in_hq']
|
self.in_hq_key = opt['in_hq']
|
||||||
self.out_lq_key = opt['out']
|
self.out_lq_key = opt['out']
|
||||||
self.out_hq_key = opt['out_hq']
|
self.out_hq_key = opt['out_hq']
|
||||||
from models.archs.flownet2.networks import Resample2d
|
from models.flownet2.networks import Resample2d
|
||||||
self.resampler = Resample2d()
|
self.resampler = Resample2d()
|
||||||
|
|
||||||
def combine(self, state):
|
def combine(self, state):
|
||||||
|
|
|
@ -14,7 +14,7 @@ def create_loss(opt_loss, env):
|
||||||
from trainer.custom_training_components import create_teco_loss
|
from trainer.custom_training_components import create_teco_loss
|
||||||
return create_teco_loss(opt_loss, env)
|
return create_teco_loss(opt_loss, env)
|
||||||
elif 'stylegan2_' in type:
|
elif 'stylegan2_' in type:
|
||||||
from models.archs.stylegan import create_stylegan2_loss
|
from models.stylegan import create_stylegan2_loss
|
||||||
return create_stylegan2_loss(opt_loss, env)
|
return create_stylegan2_loss(opt_loss, env)
|
||||||
elif type == 'crossentropy':
|
elif type == 'crossentropy':
|
||||||
return CrossEntropy(opt_loss, env)
|
return CrossEntropy(opt_loss, env)
|
||||||
|
@ -311,7 +311,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
|
|
||||||
if self.gradient_penalty:
|
if self.gradient_penalty:
|
||||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||||
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
||||||
gp = gradient_penalty(real[0], d_real)
|
gp = gradient_penalty(real[0], d_real)
|
||||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||||
|
|
|
@ -6,16 +6,16 @@ import munch
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from munch import munchify
|
from munch import munchify
|
||||||
import models.archs.stylegan.stylegan2_lucidrains as stylegan2
|
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||||
|
|
||||||
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
import models.RRDBNet_arch as RRDBNet_arch
|
||||||
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
import models.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||||
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
import models.discriminator_vgg_arch as SRGAN_arch
|
||||||
import models.archs.feature_arch as feature_arch
|
import models.feature_arch as feature_arch
|
||||||
from models.archs import srg2_classic
|
from models import srg2_classic
|
||||||
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||||
from models.archs.tecogan.teco_resgen import TecoGen
|
from models.tecogan.teco_resgen import TecoGen
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
@ -30,7 +30,7 @@ def define_G(opt, opt_net, scale=None):
|
||||||
if which_model == 'RRDBNetBypass':
|
if which_model == 'RRDBNetBypass':
|
||||||
block = RRDBNet_arch.RRDBWithBypass
|
block = RRDBNet_arch.RRDBWithBypass
|
||||||
elif which_model == 'RRDBNetLambda':
|
elif which_model == 'RRDBNetLambda':
|
||||||
from models.archs.lambda_rrdb import LambdaRRDB
|
from models.lambda_rrdb import LambdaRRDB
|
||||||
block = LambdaRRDB
|
block = LambdaRRDB
|
||||||
else:
|
else:
|
||||||
block = RRDBNet_arch.RRDB
|
block = RRDBNet_arch.RRDB
|
||||||
|
@ -62,7 +62,7 @@ def define_G(opt, opt_net, scale=None):
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||||
elif which_model == "flownet2":
|
elif which_model == "flownet2":
|
||||||
from models.archs.flownet2 import FlowNet2
|
from models.flownet2 import FlowNet2
|
||||||
ld = 'load_path' in opt_net.keys()
|
ld = 'load_path' in opt_net.keys()
|
||||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
||||||
netG = FlowNet2(args)
|
netG = FlowNet2(args)
|
||||||
|
@ -85,12 +85,12 @@ def define_G(opt, opt_net, scale=None):
|
||||||
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||||
attn_layers=attn)
|
attn_layers=attn)
|
||||||
elif which_model == 'srflow_orig':
|
elif which_model == 'srflow':
|
||||||
from models.archs.srflow_orig import SRFlowNet_arch
|
from models.srflow import SRFlowNet_arch
|
||||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||||
K=opt_net['K'], opt=opt)
|
K=opt_net['K'], opt=opt)
|
||||||
elif which_model == 'rrdb_latent_wrapper':
|
elif which_model == 'rrdb_latent_wrapper':
|
||||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
|
from models.srflow.RRDBNet_arch import RRDBLatentWrapper
|
||||||
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||||
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||||
|
@ -100,29 +100,29 @@ def define_G(opt, opt_net, scale=None):
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
||||||
headless=True, output_mode=output_mode)
|
headless=True, output_mode=output_mode)
|
||||||
elif which_model == 'rrdb_srflow':
|
elif which_model == 'rrdb_srflow':
|
||||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
from models.srflow.RRDBNet_arch import RRDBNet
|
||||||
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||||
initial_conv_stride=opt_net['initial_stride'])
|
initial_conv_stride=opt_net['initial_stride'])
|
||||||
elif which_model == 'igpt2':
|
elif which_model == 'igpt2':
|
||||||
from models.archs.transformers.igpt.gpt2 import iGPT2
|
from models.transformers.igpt.gpt2 import iGPT2
|
||||||
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
|
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
|
||||||
elif which_model == 'byol':
|
elif which_model == 'byol':
|
||||||
from models.archs.byol.byol_model_wrapper import BYOL
|
from models.byol.byol_model_wrapper import BYOL
|
||||||
subnet = define_G(opt, opt_net['subnet'])
|
subnet = define_G(opt, opt_net['subnet'])
|
||||||
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
||||||
elif which_model == 'structural_byol':
|
elif which_model == 'structural_byol':
|
||||||
from models.archs.byol.byol_structural import StructuralBYOL
|
from models.byol.byol_structural import StructuralBYOL
|
||||||
subnet = define_G(opt, opt_net['subnet'])
|
subnet = define_G(opt, opt_net['subnet'])
|
||||||
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||||
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
|
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
|
||||||
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
|
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
|
||||||
elif which_model == 'spinenet':
|
elif which_model == 'spinenet':
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.spinenet_arch import SpineNet
|
||||||
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||||
elif which_model == 'spinenet_with_logits':
|
elif which_model == 'spinenet_with_logits':
|
||||||
from models.archs.spinenet_arch import SpinenetWithLogits
|
from models.spinenet_arch import SpinenetWithLogits
|
||||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
import models.SwitchedResidualGenerator_arch as srg
|
||||||
import models.archs.discriminator_vgg_arch as disc
|
import models.discriminator_vgg_arch as disc
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
|
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user