srflow_orig integration
This commit is contained in:
parent
f80acfcab6
commit
5ccdbcefe3
|
@ -58,7 +58,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
new_net = None
|
||||
if net['type'] == 'generator':
|
||||
if new_net is None:
|
||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
||||
new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
|
||||
self.netsG[name] = new_net
|
||||
elif net['type'] == 'discriminator':
|
||||
if new_net is None:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.modules import thops
|
||||
from models.archs.srflow_orig import thops
|
||||
|
||||
|
||||
class _ActNorm(nn.Module):
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.modules import thops
|
||||
from models.modules.flow import Conv2d, Conv2dZeros
|
||||
from models.archs.srflow_orig import thops
|
||||
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -15,10 +15,10 @@ class CondAffineSeparatedAndCond(nn.Module):
|
|||
self.kernel_hidden = 1
|
||||
self.affine_eps = 0.0001
|
||||
self.n_hidden_layers = 1
|
||||
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
|
||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
|
||||
self.hidden_channels = 64 if hidden_channels is None else hidden_channels
|
||||
|
||||
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
|
||||
self.affine_eps = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
|
||||
|
||||
self.channels_for_nn = self.in_channels // 2
|
||||
self.channels_for_co = self.in_channels - self.channels_for_nn
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
import models.modules
|
||||
import models.modules.Permutations
|
||||
from models.modules import flow, thops, FlowAffineCouplingsAblation
|
||||
import models.archs.srflow_orig.Permutations
|
||||
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -47,16 +46,16 @@ class FlowStep(nn.Module):
|
|||
self.acOpt = acOpt
|
||||
|
||||
# 1. actnorm
|
||||
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
||||
self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
|
||||
|
||||
# 2. permute
|
||||
if flow_permutation == "invconv":
|
||||
self.invconv = models.modules.Permutations.InvertibleConv1x1(
|
||||
self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
|
||||
in_channels, LU_decomposed=LU_decomposed)
|
||||
|
||||
# 3. coupling
|
||||
if flow_coupling == "CondAffineSeparatedAndCond":
|
||||
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
||||
self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
|
||||
opt=opt)
|
||||
elif flow_coupling == "noCoupling":
|
||||
pass
|
||||
|
|
|
@ -2,11 +2,11 @@ import numpy as np
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
import models.modules.Split
|
||||
from models.modules import flow, thops
|
||||
from models.modules.Split import Split2d
|
||||
from models.modules.glow_arch import f_conv2d_bias
|
||||
from models.modules.FlowStep import FlowStep
|
||||
import models.archs.srflow_orig.Split
|
||||
from models.archs.srflow_orig import flow, thops
|
||||
from models.archs.srflow_orig.Split import Split2d
|
||||
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
|
||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -21,8 +21,8 @@ class FlowUpsamplerNet(nn.Module):
|
|||
|
||||
self.layers = nn.ModuleList()
|
||||
self.output_shapes = []
|
||||
self.L = opt_get(opt, ['network_G', 'flow', 'L'])
|
||||
self.K = opt_get(opt, ['network_G', 'flow', 'K'])
|
||||
self.L = opt_get(opt, ['networks', 'generator','flow', 'L'])
|
||||
self.K = opt_get(opt, ['networks', 'generator','flow', 'K'])
|
||||
if isinstance(self.K, int):
|
||||
self.K = [K for K in [K, ] * (self.L + 1)]
|
||||
|
||||
|
@ -60,11 +60,11 @@ class FlowUpsamplerNet(nn.Module):
|
|||
affineInCh = self.get_affineInCh(opt_get)
|
||||
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
|
||||
|
||||
normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
|
||||
normOpt = opt_get(opt, ['networks', 'generator','flow', 'norm'])
|
||||
|
||||
conditional_channels = {}
|
||||
n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
|
||||
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
|
||||
n_bypass_channels = opt_get(opt, ['networks', 'generator','flow', 'levelConditional', 'n_channels'])
|
||||
conditional_channels[0] = n_rrdb
|
||||
for level in range(1, self.L + 1):
|
||||
# Level 1 gets conditionals from 2, 3, 4 => L - level
|
||||
|
@ -88,7 +88,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
# Split
|
||||
self.arch_split(H, W, level, self.L, opt, opt_get)
|
||||
|
||||
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
|
||||
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']):
|
||||
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
|
||||
else:
|
||||
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
|
||||
|
@ -99,7 +99,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
self.scaleW = 160 / W
|
||||
|
||||
def get_n_rrdb_channels(self, opt, opt_get):
|
||||
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
|
||||
blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks'])
|
||||
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
|
||||
return n_rrdb
|
||||
|
||||
|
@ -126,33 +126,33 @@ class FlowUpsamplerNet(nn.Module):
|
|||
[-1, self.C, H, W])
|
||||
|
||||
def get_condAffSetting(self, opt, opt_get):
|
||||
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
|
||||
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
|
||||
condAff = opt_get(opt, ['networks', 'generator','flow', 'condAff']) or None
|
||||
condAff = opt_get(opt, ['networks', 'generator','flow', 'condFtAffine']) or condAff
|
||||
return condAff
|
||||
|
||||
def arch_split(self, H, W, L, levels, opt, opt_get):
|
||||
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
|
||||
correct_splits = opt_get(opt, ['networks', 'generator','flow', 'split', 'correct_splits'], False)
|
||||
correction = 0 if correct_splits else 1
|
||||
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
|
||||
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
|
||||
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
|
||||
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction:
|
||||
logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0
|
||||
consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5
|
||||
position_name = get_position_name(H, self.opt['scale'])
|
||||
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
|
||||
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
|
||||
position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None
|
||||
cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels'])
|
||||
cond_channels = 0 if cond_channels is None else cond_channels
|
||||
|
||||
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
|
||||
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
|
||||
|
||||
if t == 'Split2d':
|
||||
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
||||
split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
|
||||
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
|
||||
self.layers.append(split)
|
||||
self.output_shapes.append([-1, split.num_channels_pass, H, W])
|
||||
self.C = split.num_channels_pass
|
||||
|
||||
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
|
||||
if 'additionalFlowNoAffine' in opt['network_G']['flow']:
|
||||
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
|
||||
if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']:
|
||||
n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine'])
|
||||
for _ in range(n_additionalFlowNoAffine):
|
||||
self.layers.append(
|
||||
FlowStep(in_channels=self.C,
|
||||
|
@ -171,11 +171,11 @@ class FlowUpsamplerNet(nn.Module):
|
|||
return H, W
|
||||
|
||||
def get_flow_permutation(self, flow_permutation, opt):
|
||||
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
|
||||
flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv')
|
||||
return flow_permutation
|
||||
|
||||
def get_affineInCh(self, opt_get):
|
||||
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
||||
affineInCh = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||
affineInCh = (len(affineInCh) + 1) * 64
|
||||
return affineInCh
|
||||
|
||||
|
@ -204,7 +204,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
level_conditionals = {}
|
||||
bypasses = {}
|
||||
|
||||
L = opt_get(self.opt, ['network_G', 'flow', 'L'])
|
||||
L = opt_get(self.opt, ['networks', 'generator','flow', 'L'])
|
||||
|
||||
for level in range(1, L + 1):
|
||||
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
|
||||
|
@ -255,7 +255,7 @@ class FlowUpsamplerNet(nn.Module):
|
|||
# debug.imwrite("fl_fea", fl_fea)
|
||||
bypasses = {}
|
||||
level_conditionals = {}
|
||||
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
|
||||
if not opt_get(self.opt, ['networks', 'generator','flow', 'levelConditional', 'conditional']) == True:
|
||||
for level in range(self.L + 1):
|
||||
level_conditionals[level] = rrdbResults[self.levelToName[level]]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.modules import thops
|
||||
from models.archs.srflow_orig import thops
|
||||
|
||||
|
||||
class InvertibleConv1x1(nn.Module):
|
||||
|
|
|
@ -2,70 +2,148 @@ import functools
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import models.modules.module_util as mutil
|
||||
import models.archs.srflow_orig.module_util as mutil
|
||||
from models.archs.arch_util import default_init_weights, ConvGnSilu
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
class ResidualDenseBlock(nn.Module):
|
||||
"""Residual Dense Block.
|
||||
|
||||
Used in RRDB block in ESRGAN.
|
||||
|
||||
Args:
|
||||
mid_channels (int): Channel number of intermediate features.
|
||||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels=64, growth_channels=32):
|
||||
super(ResidualDenseBlock, self).__init__()
|
||||
for i in range(5):
|
||||
out_channels = mid_channels if i == 4 else growth_channels
|
||||
self.add_module(
|
||||
f'conv{i+1}',
|
||||
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
|
||||
1, 1))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
for i in range(5):
|
||||
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
|
||||
|
||||
# initialization
|
||||
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||
|
||||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
# Emperically, we use 0.2 to scale the residual for better performance
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
"""Residual in Residual Dense Block.
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
Used in RRDB-Net in ESRGAN.
|
||||
|
||||
Args:
|
||||
mid_channels (int): Channel number of intermediate features.
|
||||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels, growth_channels=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||
|
||||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
# Emperically, we use 0.2 to scale the residual for better performance
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBWithBypass(nn.Module):
|
||||
"""Residual in Residual Dense Block.
|
||||
|
||||
Used in RRDB-Net in ESRGAN.
|
||||
|
||||
Args:
|
||||
mid_channels (int): Channel number of intermediate features.
|
||||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels, growth_channels=32):
|
||||
super(RRDBWithBypass, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
|
||||
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
|
||||
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
||||
nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||
|
||||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
bypass = self.bypass(torch.cat([x, out], dim=1))
|
||||
self.bypass_map = bypass.detach().clone()
|
||||
# Empirically, we use 0.2 to scale the residual for better performance
|
||||
return out * 0.2 * bypass + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
|
||||
self.opt = opt
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass'])
|
||||
if bypass:
|
||||
RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc)
|
||||
else:
|
||||
RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
|
||||
self.scale = scale
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.body = mutil.make_layer(RRDB_block_f, nb)
|
||||
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.scale >= 8:
|
||||
self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.scale >= 16:
|
||||
self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_up4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.scale >= 32:
|
||||
self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_up5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
@ -73,23 +151,23 @@ class RRDBNet(nn.Module):
|
|||
def forward(self, x, get_steps=False):
|
||||
fea = self.conv_first(x)
|
||||
|
||||
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
||||
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||
block_results = {}
|
||||
|
||||
for idx, m in enumerate(self.RRDB_trunk.children()):
|
||||
for idx, m in enumerate(self.body.children()):
|
||||
fea = m(fea)
|
||||
for b in block_idxs:
|
||||
if b == idx:
|
||||
block_results["block_{}".format(idx)] = fea
|
||||
|
||||
trunk = self.trunk_conv(fea)
|
||||
trunk = self.conv_body(fea)
|
||||
|
||||
last_lr_fea = fea + trunk
|
||||
|
||||
fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
|
||||
fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up2)
|
||||
|
||||
fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up4)
|
||||
|
||||
fea_up8 = None
|
||||
|
@ -97,16 +175,16 @@ class RRDBNet(nn.Module):
|
|||
fea_up32 = None
|
||||
|
||||
if self.scale >= 8:
|
||||
fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up8)
|
||||
if self.scale >= 16:
|
||||
fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up16)
|
||||
if self.scale >= 32:
|
||||
fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest'))
|
||||
fea = self.lrelu(fea_up32)
|
||||
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
||||
|
||||
results = {'last_lr_fea': last_lr_fea,
|
||||
'fea_up1': last_lr_fea,
|
||||
|
@ -117,10 +195,10 @@ class RRDBNet(nn.Module):
|
|||
'fea_up32': fea_up32,
|
||||
'out': out}
|
||||
|
||||
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
|
||||
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
|
||||
if fea_up0_en:
|
||||
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
|
||||
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
|
||||
if fea_upn1_en:
|
||||
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
|
||||
|
||||
|
|
|
@ -4,10 +4,10 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from models.modules.RRDBNet_arch import RRDBNet
|
||||
from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.modules.thops as thops
|
||||
import models.modules.flow as flow
|
||||
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
|
||||
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.archs.srflow_orig.thops as thops
|
||||
import models.archs.srflow_orig.flow as flow
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -19,27 +19,40 @@ class SRFlowNet(nn.Module):
|
|||
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
|
||||
None else opt_get(opt, ['datasets', 'train', 'quant'])
|
||||
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
|
||||
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
|
||||
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
|
||||
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
|
||||
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
|
||||
|
||||
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
|
||||
hidden_channels = hidden_channels or 64
|
||||
self.RRDB_training = True # Default is true
|
||||
|
||||
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
|
||||
set_RRDB_to_train = False
|
||||
if set_RRDB_to_train:
|
||||
self.set_rrdb_training(True)
|
||||
train_RRDB_delay = opt_get(self.opt, ['networks', 'generator','train_RRDB_delay'])
|
||||
self.RRDB_training = False
|
||||
|
||||
self.flowUpsamplerNet = \
|
||||
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
|
||||
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
|
||||
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
|
||||
self.i = 0
|
||||
|
||||
def set_rrdb_training(self, trainable):
|
||||
if self.RRDB_training != trainable:
|
||||
for p in self.RRDB.parameters():
|
||||
p.requires_grad = trainable
|
||||
self.RRDB_training = trainable
|
||||
return True
|
||||
return False
|
||||
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
|
||||
if seed: torch.manual_seed(seed)
|
||||
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
|
||||
C = self.flowUpsamplerNet.C
|
||||
H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
|
||||
W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
|
||||
|
||||
size = (batch_size, C, H, W)
|
||||
if heat == 0:
|
||||
z = torch.zeros(size)
|
||||
else:
|
||||
z = torch.normal(mean=0, std=heat, size=size)
|
||||
else:
|
||||
L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
|
||||
fac = 2 ** (L - 3)
|
||||
z_size = int(self.lr_size // (2 ** (L - 3)))
|
||||
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
|
||||
return z.to(device)
|
||||
|
||||
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
|
||||
lr_enc=None,
|
||||
|
@ -48,14 +61,10 @@ class SRFlowNet(nn.Module):
|
|||
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
|
||||
y_onehot=y_label)
|
||||
else:
|
||||
# assert lr.shape[0] == 1
|
||||
assert lr.shape[1] == 3
|
||||
# assert lr.shape[2] == 20
|
||||
# assert lr.shape[3] == 20
|
||||
# assert z.shape[0] == 1
|
||||
# assert z.shape[1] == 3 * 8 * 8
|
||||
# assert z.shape[2] == 20
|
||||
# assert z.shape[3] == 20
|
||||
if z is None:
|
||||
# Synthesize it.
|
||||
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
|
||||
if reverse_with_grad:
|
||||
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
|
||||
add_gt_noise=add_gt_noise)
|
||||
|
@ -66,7 +75,11 @@ class SRFlowNet(nn.Module):
|
|||
|
||||
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
|
||||
if lr_enc is None:
|
||||
lr_enc = self.rrdbPreprocessing(lr)
|
||||
if self.RRDB_training:
|
||||
lr_enc = self.rrdbPreprocessing(lr)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
lr_enc = self.rrdbPreprocessing(lr)
|
||||
|
||||
logdet = torch.zeros_like(gt[:, 0, 0, 0])
|
||||
pixels = thops.pixels(gt)
|
||||
|
@ -75,7 +88,7 @@ class SRFlowNet(nn.Module):
|
|||
|
||||
if add_gt_noise:
|
||||
# Setup
|
||||
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
|
||||
noiseQuant = opt_get(self.opt, ['networks', 'generator','flow', 'augmentation', 'noiseQuant'], True)
|
||||
if noiseQuant:
|
||||
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
|
||||
logdet = logdet + float(-np.log(self.quant) * pixels)
|
||||
|
@ -101,11 +114,11 @@ class SRFlowNet(nn.Module):
|
|||
|
||||
def rrdbPreprocessing(self, lr):
|
||||
rrdbResults = self.RRDB(lr, get_steps=True)
|
||||
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
|
||||
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
|
||||
if len(block_idxs) > 0:
|
||||
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
|
||||
|
||||
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
|
||||
if opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'concat']) or False:
|
||||
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
|
||||
if 'fea_up0' in rrdbResults.keys():
|
||||
keys.append('fea_up0')
|
||||
|
@ -134,7 +147,11 @@ class SRFlowNet(nn.Module):
|
|||
logdet = logdet - float(-np.log(self.quant) * pixels)
|
||||
|
||||
if lr_enc is None:
|
||||
lr_enc = self.rrdbPreprocessing(lr)
|
||||
if self.RRDB_training:
|
||||
lr_enc = self.rrdbPreprocessing(lr)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
lr_enc = self.rrdbPreprocessing(lr)
|
||||
|
||||
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
|
||||
logdet=logdet)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.modules import thops
|
||||
from models.modules.FlowStep import FlowStep
|
||||
from models.modules.flow import Conv2dZeros, GaussianDiag
|
||||
from models.archs.srflow_orig import thops
|
||||
from models.archs.srflow_orig.FlowStep import FlowStep
|
||||
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from models.modules.FlowActNorms import ActNorm2d
|
||||
from models.archs.srflow_orig.FlowActNorms import ActNorm2d
|
||||
from . import thops
|
||||
|
||||
|
||||
|
|
|
@ -105,12 +105,25 @@ class BaseModel():
|
|||
if 'state_dict' in load_net:
|
||||
load_net = load_net['state_dict']
|
||||
|
||||
is_srflow = False
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
if k.startswith('generator'): # Hack to fix ESRGAN pretrained model.
|
||||
load_net_clean[k[10:]] = v
|
||||
if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names.
|
||||
is_srflow = True
|
||||
fixed_key = k.replace('RRDB_trunk', 'body')
|
||||
if '.RDB' in fixed_key:
|
||||
fixed_key = fixed_key.replace('.RDB', '.rdb')
|
||||
elif '.upconv' in fixed_key:
|
||||
fixed_key = fixed_key.replace('.upconv', '.conv_up')
|
||||
elif '.trunk_conv' in fixed_key:
|
||||
fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
|
||||
elif '.HRconv' in fixed_key:
|
||||
fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
|
||||
load_net_clean[fixed_key] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
network.load_state_dict(load_net_clean, strict=strict)
|
||||
|
|
|
@ -28,11 +28,7 @@ from models.archs.teco_resgen import TecoGen
|
|||
logger = logging.getLogger('base')
|
||||
|
||||
# Generator
|
||||
def define_G(opt, net_key='network_G', scale=None):
|
||||
if net_key is not None:
|
||||
opt_net = opt[net_key]
|
||||
else:
|
||||
opt_net = opt
|
||||
def define_G(opt, opt_net, scale=None):
|
||||
if scale is None:
|
||||
scale = opt['scale']
|
||||
which_model = opt_net['which_model_G']
|
||||
|
@ -141,6 +137,17 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||
attn_layers=attn)
|
||||
elif which_model == 'srflow':
|
||||
from models.archs.srflow import SRFlow_arch
|
||||
netG = SRFlow_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'],
|
||||
quant=opt_net['quant'], flow_block_maps=opt_net['rrdb_block_maps'],
|
||||
noise_quant=opt_net['noise_quant'], hidden_channels=opt_net['nf'],
|
||||
K=opt_net['K'], L=opt_net['L'], train_rrdb_at_step=opt_net['rrdb_train_step'],
|
||||
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
|
||||
elif which_model == 'srflow_orig':
|
||||
from models.archs.srflow_orig import SRFlowNet_arch
|
||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
|
||||
K=opt_net['K'], opt=opt)
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
|
@ -157,6 +157,8 @@ class AddNoiseInjector(Injector):
|
|||
scale = state[self.opt['scale']]
|
||||
else:
|
||||
scale = self.opt['scale']
|
||||
if scale is None:
|
||||
scale = 1
|
||||
|
||||
ref = state[self.opt['in']]
|
||||
if self.mode == 'normal':
|
||||
|
|
|
@ -139,7 +139,8 @@ class ConfigurableStep(Module):
|
|||
continue
|
||||
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
|
||||
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
|
||||
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
|
||||
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
|
||||
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
|
||||
continue
|
||||
injected = inj(local_state)
|
||||
local_state.update(injected)
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user