srflow_orig integration
This commit is contained in:
parent
f80acfcab6
commit
5ccdbcefe3
|
@ -58,7 +58,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
new_net = None
|
new_net = None
|
||||||
if net['type'] == 'generator':
|
if net['type'] == 'generator':
|
||||||
if new_net is None:
|
if new_net is None:
|
||||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
|
||||||
self.netsG[name] = new_net
|
self.netsG[name] = new_net
|
||||||
elif net['type'] == 'discriminator':
|
elif net['type'] == 'discriminator':
|
||||||
if new_net is None:
|
if new_net is None:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.modules import thops
|
from models.archs.srflow_orig import thops
|
||||||
|
|
||||||
|
|
||||||
class _ActNorm(nn.Module):
|
class _ActNorm(nn.Module):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.modules import thops
|
from models.archs.srflow_orig import thops
|
||||||
from models.modules.flow import Conv2d, Conv2dZeros
|
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,10 +15,10 @@ class CondAffineSeparatedAndCond(nn.Module):
|
||||||
self.kernel_hidden = 1
|
self.kernel_hidden = 1
|
||||||
self.affine_eps = 0.0001
|
self.affine_eps = 0.0001
|
||||||
self.n_hidden_layers = 1
|
self.n_hidden_layers = 1
|
||||||
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
|
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
|
||||||
self.hidden_channels = 64 if hidden_channels is None else hidden_channels
|
self.hidden_channels = 64 if hidden_channels is None else hidden_channels
|
||||||
|
|
||||||
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
|
self.affine_eps = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
|
||||||
|
|
||||||
self.channels_for_nn = self.in_channels // 2
|
self.channels_for_nn = self.in_channels // 2
|
||||||
self.channels_for_co = self.in_channels - self.channels_for_nn
|
self.channels_for_co = self.in_channels - self.channels_for_nn
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
import models.modules
|
import models.archs.srflow_orig.Permutations
|
||||||
import models.modules.Permutations
|
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
|
||||||
from models.modules import flow, thops, FlowAffineCouplingsAblation
|
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,16 +46,16 @@ class FlowStep(nn.Module):
|
||||||
self.acOpt = acOpt
|
self.acOpt = acOpt
|
||||||
|
|
||||||
# 1. actnorm
|
# 1. actnorm
|
||||||
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
||||||
|
|
||||||
# 2. permute
|
# 2. permute
|
||||||
if flow_permutation == "invconv":
|
if flow_permutation == "invconv":
|
||||||
self.invconv = models.modules.Permutations.InvertibleConv1x1(
|
self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
|
||||||
in_channels, LU_decomposed=LU_decomposed)
|
in_channels, LU_decomposed=LU_decomposed)
|
||||||
|
|
||||||
# 3. coupling
|
# 3. coupling
|
||||||
if flow_coupling == "CondAffineSeparatedAndCond":
|
if flow_coupling == "CondAffineSeparatedAndCond":
|
||||||
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
||||||
opt=opt)
|
opt=opt)
|
||||||
elif flow_coupling == "noCoupling":
|
elif flow_coupling == "noCoupling":
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -2,11 +2,11 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
import models.modules.Split
|
import models.archs.srflow_orig.Split
|
||||||
from models.modules import flow, thops
|
from models.archs.srflow_orig import flow, thops
|
||||||
from models.modules.Split import Split2d
|
from models.archs.srflow_orig.Split import Split2d
|
||||||
from models.modules.glow_arch import f_conv2d_bias
|
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||||
from models.modules.FlowStep import FlowStep
|
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,8 +21,8 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
|
|
||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
self.output_shapes = []
|
self.output_shapes = []
|
||||||
self.L = opt_get(opt, ['network_G', 'flow', 'L'])
|
self.L = opt_get(opt, ['networks', 'generator','flow', 'L'])
|
||||||
self.K = opt_get(opt, ['network_G', 'flow', 'K'])
|
self.K = opt_get(opt, ['networks', 'generator','flow', 'K'])
|
||||||
if isinstance(self.K, int):
|
if isinstance(self.K, int):
|
||||||
self.K = [K for K in [K, ] * (self.L + 1)]
|
self.K = [K for K in [K, ] * (self.L + 1)]
|
||||||
|
|
||||||
|
@ -60,11 +60,11 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
affineInCh = self.get_affineInCh(opt_get)
|
affineInCh = self.get_affineInCh(opt_get)
|
||||||
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
|
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
|
||||||
|
|
||||||
normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
|
normOpt = opt_get(opt, ['networks', 'generator','flow', 'norm'])
|
||||||
|
|
||||||
conditional_channels = {}
|
conditional_channels = {}
|
||||||
n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
|
n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
|
||||||
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
|
n_bypass_channels = opt_get(opt, ['networks', 'generator','flow', 'levelConditional', 'n_channels'])
|
||||||
conditional_channels[0] = n_rrdb
|
conditional_channels[0] = n_rrdb
|
||||||
for level in range(1, self.L + 1):
|
for level in range(1, self.L + 1):
|
||||||
# Level 1 gets conditionals from 2, 3, 4 => L - level
|
# Level 1 gets conditionals from 2, 3, 4 => L - level
|
||||||
|
@ -88,7 +88,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
# Split
|
# Split
|
||||||
self.arch_split(H, W, level, self.L, opt, opt_get)
|
self.arch_split(H, W, level, self.L, opt, opt_get)
|
||||||
|
|
||||||
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
|
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']):
|
||||||
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
|
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
|
||||||
else:
|
else:
|
||||||
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
|
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
|
||||||
|
@ -99,7 +99,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
self.scaleW = 160 / W
|
self.scaleW = 160 / W
|
||||||
|
|
||||||
def get_n_rrdb_channels(self, opt, opt_get):
|
def get_n_rrdb_channels(self, opt, opt_get):
|
||||||
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
|
blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks'])
|
||||||
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
|
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
|
||||||
return n_rrdb
|
return n_rrdb
|
||||||
|
|
||||||
|
@ -126,33 +126,33 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
[-1, self.C, H, W])
|
[-1, self.C, H, W])
|
||||||
|
|
||||||
def get_condAffSetting(self, opt, opt_get):
|
def get_condAffSetting(self, opt, opt_get):
|
||||||
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
|
condAff = opt_get(opt, ['networks', 'generator','flow', 'condAff']) or None
|
||||||
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
|
condAff = opt_get(opt, ['networks', 'generator','flow', 'condFtAffine']) or condAff
|
||||||
return condAff
|
return condAff
|
||||||
|
|
||||||
def arch_split(self, H, W, L, levels, opt, opt_get):
|
def arch_split(self, H, W, L, levels, opt, opt_get):
|
||||||
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
|
correct_splits = opt_get(opt, ['networks', 'generator','flow', 'split', 'correct_splits'], False)
|
||||||
correction = 0 if correct_splits else 1
|
correction = 0 if correct_splits else 1
|
||||||
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
|
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction:
|
||||||
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
|
logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0
|
||||||
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
|
consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5
|
||||||
position_name = get_position_name(H, self.opt['scale'])
|
position_name = get_position_name(H, self.opt['scale'])
|
||||||
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
|
position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None
|
||||||
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
|
cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels'])
|
||||||
cond_channels = 0 if cond_channels is None else cond_channels
|
cond_channels = 0 if cond_channels is None else cond_channels
|
||||||
|
|
||||||
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
|
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
|
||||||
|
|
||||||
if t == 'Split2d':
|
if t == 'Split2d':
|
||||||
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
||||||
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
||||||
self.layers.append(split)
|
self.layers.append(split)
|
||||||
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
||||||
self.C = split.num_channels_pass
|
self.C = split.num_channels_pass
|
||||||
|
|
||||||
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
|
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
|
||||||
if 'additionalFlowNoAffine' in opt['network_G']['flow']:
|
if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']:
|
||||||
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
|
n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine'])
|
||||||
for _ in range(n_additionalFlowNoAffine):
|
for _ in range(n_additionalFlowNoAffine):
|
||||||
self.layers.append(
|
self.layers.append(
|
||||||
FlowStep(in_channels=self.C,
|
FlowStep(in_channels=self.C,
|
||||||
|
@ -171,11 +171,11 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
return H, W
|
return H, W
|
||||||
|
|
||||||
def get_flow_permutation(self, flow_permutation, opt):
|
def get_flow_permutation(self, flow_permutation, opt):
|
||||||
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
|
flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv')
|
||||||
return flow_permutation
|
return flow_permutation
|
||||||
|
|
||||||
def get_affineInCh(self, opt_get):
|
def get_affineInCh(self, opt_get):
|
||||||
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
affineInCh = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||||
affineInCh = (len(affineInCh) + 1) * 64
|
affineInCh = (len(affineInCh) + 1) * 64
|
||||||
return affineInCh
|
return affineInCh
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
level_conditionals = {}
|
level_conditionals = {}
|
||||||
bypasses = {}
|
bypasses = {}
|
||||||
|
|
||||||
L = opt_get(self.opt, ['network_G', 'flow', 'L'])
|
L = opt_get(self.opt, ['networks', 'generator','flow', 'L'])
|
||||||
|
|
||||||
for level in range(1, L + 1):
|
for level in range(1, L + 1):
|
||||||
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
|
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
|
||||||
|
@ -255,7 +255,7 @@ class FlowUpsamplerNet(nn.Module):
|
||||||
# debug.imwrite("fl_fea", fl_fea)
|
# debug.imwrite("fl_fea", fl_fea)
|
||||||
bypasses = {}
|
bypasses = {}
|
||||||
level_conditionals = {}
|
level_conditionals = {}
|
||||||
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
|
if not opt_get(self.opt, ['networks', 'generator','flow', 'levelConditional', 'conditional']) == True:
|
||||||
for level in range(self.L + 1):
|
for level in range(self.L + 1):
|
||||||
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from models.modules import thops
|
from models.archs.srflow_orig import thops
|
||||||
|
|
||||||
|
|
||||||
class InvertibleConv1x1(nn.Module):
|
class InvertibleConv1x1(nn.Module):
|
||||||
|
|
|
@ -2,70 +2,148 @@ import functools
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import models.modules.module_util as mutil
|
import models.archs.srflow_orig.module_util as mutil
|
||||||
|
from models.archs.arch_util import default_init_weights, ConvGnSilu
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock_5C(nn.Module):
|
class ResidualDenseBlock(nn.Module):
|
||||||
def __init__(self, nf=64, gc=32, bias=True):
|
"""Residual Dense Block.
|
||||||
super(ResidualDenseBlock_5C, self).__init__()
|
|
||||||
# gc: growth channel, i.e. intermediate channels
|
Used in RRDB block in ESRGAN.
|
||||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
Args:
|
||||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
mid_channels (int): Channel number of intermediate features.
|
||||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
growth_channels (int): Channels for each growth.
|
||||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
"""
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
def __init__(self, mid_channels=64, growth_channels=32):
|
||||||
|
super(ResidualDenseBlock, self).__init__()
|
||||||
|
for i in range(5):
|
||||||
|
out_channels = mid_channels if i == 4 else growth_channels
|
||||||
|
self.add_module(
|
||||||
|
f'conv{i+1}',
|
||||||
|
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
|
||||||
|
1, 1))
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
for i in range(5):
|
||||||
|
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
|
||||||
|
|
||||||
# initialization
|
|
||||||
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
"""Forward function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Forward results.
|
||||||
|
"""
|
||||||
x1 = self.lrelu(self.conv1(x))
|
x1 = self.lrelu(self.conv1(x))
|
||||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
# Emperically, we use 0.2 to scale the residual for better performance
|
||||||
return x5 * 0.2 + x
|
return x5 * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
class RRDB(nn.Module):
|
||||||
'''Residual in Residual Dense Block'''
|
"""Residual in Residual Dense Block.
|
||||||
|
|
||||||
def __init__(self, nf, gc=32):
|
Used in RRDB-Net in ESRGAN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mid_channels (int): Channel number of intermediate features.
|
||||||
|
growth_channels (int): Channels for each growth.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels, growth_channels=32):
|
||||||
super(RRDB, self).__init__()
|
super(RRDB, self).__init__()
|
||||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.RDB1(x)
|
"""Forward function.
|
||||||
out = self.RDB2(out)
|
|
||||||
out = self.RDB3(out)
|
Args:
|
||||||
|
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Forward results.
|
||||||
|
"""
|
||||||
|
out = self.rdb1(x)
|
||||||
|
out = self.rdb2(out)
|
||||||
|
out = self.rdb3(out)
|
||||||
|
# Emperically, we use 0.2 to scale the residual for better performance
|
||||||
return out * 0.2 + x
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDBWithBypass(nn.Module):
|
||||||
|
"""Residual in Residual Dense Block.
|
||||||
|
|
||||||
|
Used in RRDB-Net in ESRGAN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mid_channels (int): Channel number of intermediate features.
|
||||||
|
growth_channels (int): Channels for each growth.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels, growth_channels=32):
|
||||||
|
super(RRDBWithBypass, self).__init__()
|
||||||
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
|
||||||
|
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
|
||||||
|
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
||||||
|
nn.Sigmoid())
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Forward results.
|
||||||
|
"""
|
||||||
|
out = self.rdb1(x)
|
||||||
|
out = self.rdb2(out)
|
||||||
|
out = self.rdb3(out)
|
||||||
|
bypass = self.bypass(torch.cat([x, out], dim=1))
|
||||||
|
self.bypass_map = bypass.detach().clone()
|
||||||
|
# Empirically, we use 0.2 to scale the residual for better performance
|
||||||
|
return out * 0.2 * bypass + x
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
class RRDBNet(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
|
||||||
|
bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass'])
|
||||||
|
if bypass:
|
||||||
|
RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc)
|
||||||
|
else:
|
||||||
|
RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
|
self.body = mutil.make_layer(RRDB_block_f, nb)
|
||||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
#### upsampling
|
#### upsampling
|
||||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
if self.scale >= 8:
|
if self.scale >= 8:
|
||||||
self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
if self.scale >= 16:
|
if self.scale >= 16:
|
||||||
self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_up4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
if self.scale >= 32:
|
if self.scale >= 32:
|
||||||
self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_up5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
@ -73,23 +151,23 @@ class RRDBNet(nn.Module):
|
||||||
def forward(self, x, get_steps=False):
|
def forward(self, x, get_steps=False):
|
||||||
fea = self.conv_first(x)
|
fea = self.conv_first(x)
|
||||||
|
|
||||||
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||||
block_results = {}
|
block_results = {}
|
||||||
|
|
||||||
for idx, m in enumerate(self.RRDB_trunk.children()):
|
for idx, m in enumerate(self.body.children()):
|
||||||
fea = m(fea)
|
fea = m(fea)
|
||||||
for b in block_idxs:
|
for b in block_idxs:
|
||||||
if b == idx:
|
if b == idx:
|
||||||
block_results["block_{}".format(idx)] = fea
|
block_results["block_{}".format(idx)] = fea
|
||||||
|
|
||||||
trunk = self.trunk_conv(fea)
|
trunk = self.conv_body(fea)
|
||||||
|
|
||||||
last_lr_fea = fea + trunk
|
last_lr_fea = fea + trunk
|
||||||
|
|
||||||
fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
|
fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
|
||||||
fea = self.lrelu(fea_up2)
|
fea = self.lrelu(fea_up2)
|
||||||
|
|
||||||
fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||||
fea = self.lrelu(fea_up4)
|
fea = self.lrelu(fea_up4)
|
||||||
|
|
||||||
fea_up8 = None
|
fea_up8 = None
|
||||||
|
@ -97,16 +175,16 @@ class RRDBNet(nn.Module):
|
||||||
fea_up32 = None
|
fea_up32 = None
|
||||||
|
|
||||||
if self.scale >= 8:
|
if self.scale >= 8:
|
||||||
fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||||
fea = self.lrelu(fea_up8)
|
fea = self.lrelu(fea_up8)
|
||||||
if self.scale >= 16:
|
if self.scale >= 16:
|
||||||
fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||||
fea = self.lrelu(fea_up16)
|
fea = self.lrelu(fea_up16)
|
||||||
if self.scale >= 32:
|
if self.scale >= 32:
|
||||||
fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||||
fea = self.lrelu(fea_up32)
|
fea = self.lrelu(fea_up32)
|
||||||
|
|
||||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
||||||
|
|
||||||
results = {'last_lr_fea': last_lr_fea,
|
results = {'last_lr_fea': last_lr_fea,
|
||||||
'fea_up1': last_lr_fea,
|
'fea_up1': last_lr_fea,
|
||||||
|
@ -117,10 +195,10 @@ class RRDBNet(nn.Module):
|
||||||
'fea_up32': fea_up32,
|
'fea_up32': fea_up32,
|
||||||
'out': out}
|
'out': out}
|
||||||
|
|
||||||
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
|
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
||||||
if fea_up0_en:
|
if fea_up0_en:
|
||||||
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
|
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||||
if fea_upn1_en:
|
if fea_upn1_en:
|
||||||
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,10 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from models.modules.RRDBNet_arch import RRDBNet
|
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||||
from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
|
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
||||||
import models.modules.thops as thops
|
import models.archs.srflow_orig.thops as thops
|
||||||
import models.modules.flow as flow
|
import models.archs.srflow_orig.flow as flow
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,27 +19,40 @@ class SRFlowNet(nn.Module):
|
||||||
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
||||||
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
||||||
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
||||||
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
|
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
||||||
|
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
||||||
|
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
|
||||||
|
|
||||||
|
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||||
hidden_channels = hidden_channels or 64
|
hidden_channels = hidden_channels or 64
|
||||||
self.RRDB_training = True # Default is true
|
self.RRDB_training = True # Default is true
|
||||||
|
|
||||||
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
|
train_RRDB_delay = opt_get(self.opt, ['networks', 'generator','train_RRDB_delay'])
|
||||||
set_RRDB_to_train = False
|
self.RRDB_training = False
|
||||||
if set_RRDB_to_train:
|
|
||||||
self.set_rrdb_training(True)
|
|
||||||
|
|
||||||
self.flowUpsamplerNet = \
|
self.flowUpsamplerNet = \
|
||||||
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
||||||
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
|
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
|
||||||
self.i = 0
|
self.i = 0
|
||||||
|
|
||||||
def set_rrdb_training(self, trainable):
|
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
|
||||||
if self.RRDB_training != trainable:
|
if seed: torch.manual_seed(seed)
|
||||||
for p in self.RRDB.parameters():
|
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
||||||
p.requires_grad = trainable
|
C = self.flowUpsamplerNet.C
|
||||||
self.RRDB_training = trainable
|
H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
|
||||||
return True
|
W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
|
||||||
return False
|
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
if heat == 0:
|
||||||
|
z = torch.zeros(size)
|
||||||
|
else:
|
||||||
|
z = torch.normal(mean=0, std=heat, size=size)
|
||||||
|
else:
|
||||||
|
L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
|
||||||
|
fac = 2 ** (L - 3)
|
||||||
|
z_size = int(self.lr_size // (2 ** (L - 3)))
|
||||||
|
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
|
||||||
|
return z.to(device)
|
||||||
|
|
||||||
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
||||||
lr_enc=None,
|
lr_enc=None,
|
||||||
|
@ -48,14 +61,10 @@ class SRFlowNet(nn.Module):
|
||||||
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
|
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
|
||||||
y_onehot=y_label)
|
y_onehot=y_label)
|
||||||
else:
|
else:
|
||||||
# assert lr.shape[0] == 1
|
|
||||||
assert lr.shape[1] == 3
|
assert lr.shape[1] == 3
|
||||||
# assert lr.shape[2] == 20
|
if z is None:
|
||||||
# assert lr.shape[3] == 20
|
# Synthesize it.
|
||||||
# assert z.shape[0] == 1
|
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
|
||||||
# assert z.shape[1] == 3 * 8 * 8
|
|
||||||
# assert z.shape[2] == 20
|
|
||||||
# assert z.shape[3] == 20
|
|
||||||
if reverse_with_grad:
|
if reverse_with_grad:
|
||||||
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
||||||
add_gt_noise=add_gt_noise)
|
add_gt_noise=add_gt_noise)
|
||||||
|
@ -66,6 +75,10 @@ class SRFlowNet(nn.Module):
|
||||||
|
|
||||||
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
|
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
|
||||||
if lr_enc is None:
|
if lr_enc is None:
|
||||||
|
if self.RRDB_training:
|
||||||
|
lr_enc = self.rrdbPreprocessing(lr)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
lr_enc = self.rrdbPreprocessing(lr)
|
lr_enc = self.rrdbPreprocessing(lr)
|
||||||
|
|
||||||
logdet = torch.zeros_like(gt[:, 0, 0, 0])
|
logdet = torch.zeros_like(gt[:, 0, 0, 0])
|
||||||
|
@ -75,7 +88,7 @@ class SRFlowNet(nn.Module):
|
||||||
|
|
||||||
if add_gt_noise:
|
if add_gt_noise:
|
||||||
# Setup
|
# Setup
|
||||||
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
|
noiseQuant = opt_get(self.opt, ['networks', 'generator','flow', 'augmentation', 'noiseQuant'], True)
|
||||||
if noiseQuant:
|
if noiseQuant:
|
||||||
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
|
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
|
||||||
logdet = logdet + float(-np.log(self.quant) * pixels)
|
logdet = logdet + float(-np.log(self.quant) * pixels)
|
||||||
|
@ -101,11 +114,11 @@ class SRFlowNet(nn.Module):
|
||||||
|
|
||||||
def rrdbPreprocessing(self, lr):
|
def rrdbPreprocessing(self, lr):
|
||||||
rrdbResults = self.RRDB(lr, get_steps=True)
|
rrdbResults = self.RRDB(lr, get_steps=True)
|
||||||
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||||
if len(block_idxs) > 0:
|
if len(block_idxs) > 0:
|
||||||
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
||||||
|
|
||||||
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
|
if opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'concat']) or False:
|
||||||
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
|
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
|
||||||
if 'fea_up0' in rrdbResults.keys():
|
if 'fea_up0' in rrdbResults.keys():
|
||||||
keys.append('fea_up0')
|
keys.append('fea_up0')
|
||||||
|
@ -134,6 +147,10 @@ class SRFlowNet(nn.Module):
|
||||||
logdet = logdet - float(-np.log(self.quant) * pixels)
|
logdet = logdet - float(-np.log(self.quant) * pixels)
|
||||||
|
|
||||||
if lr_enc is None:
|
if lr_enc is None:
|
||||||
|
if self.RRDB_training:
|
||||||
|
lr_enc = self.rrdbPreprocessing(lr)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
lr_enc = self.rrdbPreprocessing(lr)
|
lr_enc = self.rrdbPreprocessing(lr)
|
||||||
|
|
||||||
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
|
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from models.modules import thops
|
from models.archs.srflow_orig import thops
|
||||||
from models.modules.FlowStep import FlowStep
|
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||||
from models.modules.flow import Conv2dZeros, GaussianDiag
|
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from models.modules.FlowActNorms import ActNorm2d
|
from models.archs.srflow_orig.FlowActNorms import ActNorm2d
|
||||||
from . import thops
|
from . import thops
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -105,12 +105,25 @@ class BaseModel():
|
||||||
if 'state_dict' in load_net:
|
if 'state_dict' in load_net:
|
||||||
load_net = load_net['state_dict']
|
load_net = load_net['state_dict']
|
||||||
|
|
||||||
|
is_srflow = False
|
||||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||||
for k, v in load_net.items():
|
for k, v in load_net.items():
|
||||||
if k.startswith('module.'):
|
if k.startswith('module.'):
|
||||||
load_net_clean[k[7:]] = v
|
load_net_clean[k[7:]] = v
|
||||||
if k.startswith('generator'): # Hack to fix ESRGAN pretrained model.
|
if k.startswith('generator'): # Hack to fix ESRGAN pretrained model.
|
||||||
load_net_clean[k[10:]] = v
|
load_net_clean[k[10:]] = v
|
||||||
|
if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names.
|
||||||
|
is_srflow = True
|
||||||
|
fixed_key = k.replace('RRDB_trunk', 'body')
|
||||||
|
if '.RDB' in fixed_key:
|
||||||
|
fixed_key = fixed_key.replace('.RDB', '.rdb')
|
||||||
|
elif '.upconv' in fixed_key:
|
||||||
|
fixed_key = fixed_key.replace('.upconv', '.conv_up')
|
||||||
|
elif '.trunk_conv' in fixed_key:
|
||||||
|
fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
|
||||||
|
elif '.HRconv' in fixed_key:
|
||||||
|
fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
|
||||||
|
load_net_clean[fixed_key] = v
|
||||||
else:
|
else:
|
||||||
load_net_clean[k] = v
|
load_net_clean[k] = v
|
||||||
network.load_state_dict(load_net_clean, strict=strict)
|
network.load_state_dict(load_net_clean, strict=strict)
|
||||||
|
|
|
@ -28,11 +28,7 @@ from models.archs.teco_resgen import TecoGen
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
# Generator
|
# Generator
|
||||||
def define_G(opt, net_key='network_G', scale=None):
|
def define_G(opt, opt_net, scale=None):
|
||||||
if net_key is not None:
|
|
||||||
opt_net = opt[net_key]
|
|
||||||
else:
|
|
||||||
opt_net = opt
|
|
||||||
if scale is None:
|
if scale is None:
|
||||||
scale = opt['scale']
|
scale = opt['scale']
|
||||||
which_model = opt_net['which_model_G']
|
which_model = opt_net['which_model_G']
|
||||||
|
@ -141,6 +137,17 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||||
attn_layers=attn)
|
attn_layers=attn)
|
||||||
|
elif which_model == 'srflow':
|
||||||
|
from models.archs.srflow import SRFlow_arch
|
||||||
|
netG = SRFlow_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'],
|
||||||
|
quant=opt_net['quant'], flow_block_maps=opt_net['rrdb_block_maps'],
|
||||||
|
noise_quant=opt_net['noise_quant'], hidden_channels=opt_net['nf'],
|
||||||
|
K=opt_net['K'], L=opt_net['L'], train_rrdb_at_step=opt_net['rrdb_train_step'],
|
||||||
|
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
|
||||||
|
elif which_model == 'srflow_orig':
|
||||||
|
from models.archs.srflow_orig import SRFlowNet_arch
|
||||||
|
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
|
||||||
|
K=opt_net['K'], opt=opt)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
|
|
@ -157,6 +157,8 @@ class AddNoiseInjector(Injector):
|
||||||
scale = state[self.opt['scale']]
|
scale = state[self.opt['scale']]
|
||||||
else:
|
else:
|
||||||
scale = self.opt['scale']
|
scale = self.opt['scale']
|
||||||
|
if scale is None:
|
||||||
|
scale = 1
|
||||||
|
|
||||||
ref = state[self.opt['in']]
|
ref = state[self.opt['in']]
|
||||||
if self.mode == 'normal':
|
if self.mode == 'normal':
|
||||||
|
|
|
@ -139,7 +139,8 @@ class ConfigurableStep(Module):
|
||||||
continue
|
continue
|
||||||
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
|
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
|
||||||
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
|
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
|
||||||
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
|
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
|
||||||
|
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
|
||||||
continue
|
continue
|
||||||
injected = inj(local_state)
|
injected = inj(local_state)
|
||||||
local_state.update(injected)
|
local_state.update(injected)
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user