srflow_orig integration

This commit is contained in:
James Betker 2020-11-19 23:47:24 -07:00
parent f80acfcab6
commit 5ccdbcefe3
16 changed files with 239 additions and 122 deletions

View File

@ -58,7 +58,7 @@ class ExtensibleTrainer(BaseModel):
new_net = None
if net['type'] == 'generator':
if new_net is None:
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
if new_net is None:

View File

@ -1,7 +1,7 @@
import torch
from torch import nn as nn
from models.modules import thops
from models.archs.srflow_orig import thops
class _ActNorm(nn.Module):

View File

@ -1,8 +1,8 @@
import torch
from torch import nn as nn
from models.modules import thops
from models.modules.flow import Conv2d, Conv2dZeros
from models.archs.srflow_orig import thops
from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
from utils.util import opt_get
@ -15,10 +15,10 @@ class CondAffineSeparatedAndCond(nn.Module):
self.kernel_hidden = 1
self.affine_eps = 0.0001
self.n_hidden_layers = 1
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
self.hidden_channels = 64 if hidden_channels is None else hidden_channels
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
self.affine_eps = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
self.channels_for_nn = self.in_channels // 2
self.channels_for_co = self.in_channels - self.channels_for_nn

View File

@ -1,9 +1,8 @@
import torch
from torch import nn as nn
import models.modules
import models.modules.Permutations
from models.modules import flow, thops, FlowAffineCouplingsAblation
import models.archs.srflow_orig.Permutations
from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
from utils.util import opt_get
@ -47,16 +46,16 @@ class FlowStep(nn.Module):
self.acOpt = acOpt
# 1. actnorm
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
# 2. permute
if flow_permutation == "invconv":
self.invconv = models.modules.Permutations.InvertibleConv1x1(
self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
in_channels, LU_decomposed=LU_decomposed)
# 3. coupling
if flow_coupling == "CondAffineSeparatedAndCond":
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
opt=opt)
elif flow_coupling == "noCoupling":
pass

View File

@ -2,11 +2,11 @@ import numpy as np
import torch
from torch import nn as nn
import models.modules.Split
from models.modules import flow, thops
from models.modules.Split import Split2d
from models.modules.glow_arch import f_conv2d_bias
from models.modules.FlowStep import FlowStep
import models.archs.srflow_orig.Split
from models.archs.srflow_orig import flow, thops
from models.archs.srflow_orig.Split import Split2d
from models.archs.srflow_orig.glow_arch import f_conv2d_bias
from models.archs.srflow_orig.FlowStep import FlowStep
from utils.util import opt_get
@ -21,8 +21,8 @@ class FlowUpsamplerNet(nn.Module):
self.layers = nn.ModuleList()
self.output_shapes = []
self.L = opt_get(opt, ['network_G', 'flow', 'L'])
self.K = opt_get(opt, ['network_G', 'flow', 'K'])
self.L = opt_get(opt, ['networks', 'generator','flow', 'L'])
self.K = opt_get(opt, ['networks', 'generator','flow', 'K'])
if isinstance(self.K, int):
self.K = [K for K in [K, ] * (self.L + 1)]
@ -60,11 +60,11 @@ class FlowUpsamplerNet(nn.Module):
affineInCh = self.get_affineInCh(opt_get)
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
normOpt = opt_get(opt, ['networks', 'generator','flow', 'norm'])
conditional_channels = {}
n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
n_bypass_channels = opt_get(opt, ['networks', 'generator','flow', 'levelConditional', 'n_channels'])
conditional_channels[0] = n_rrdb
for level in range(1, self.L + 1):
# Level 1 gets conditionals from 2, 3, 4 => L - level
@ -88,7 +88,7 @@ class FlowUpsamplerNet(nn.Module):
# Split
self.arch_split(H, W, level, self.L, opt, opt_get)
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']):
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
else:
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
@ -99,7 +99,7 @@ class FlowUpsamplerNet(nn.Module):
self.scaleW = 160 / W
def get_n_rrdb_channels(self, opt, opt_get):
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks'])
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
return n_rrdb
@ -126,33 +126,33 @@ class FlowUpsamplerNet(nn.Module):
[-1, self.C, H, W])
def get_condAffSetting(self, opt, opt_get):
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
condAff = opt_get(opt, ['networks', 'generator','flow', 'condAff']) or None
condAff = opt_get(opt, ['networks', 'generator','flow', 'condFtAffine']) or condAff
return condAff
def arch_split(self, H, W, L, levels, opt, opt_get):
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
correct_splits = opt_get(opt, ['networks', 'generator','flow', 'split', 'correct_splits'], False)
correction = 0 if correct_splits else 1
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction:
logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0
consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5
position_name = get_position_name(H, self.opt['scale'])
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None
cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels'])
cond_channels = 0 if cond_channels is None else cond_channels
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
if t == 'Split2d':
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
self.layers.append(split)
self.output_shapes.append([-1, split.num_channels_pass, H, W])
self.C = split.num_channels_pass
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
if 'additionalFlowNoAffine' in opt['network_G']['flow']:
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']:
n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine'])
for _ in range(n_additionalFlowNoAffine):
self.layers.append(
FlowStep(in_channels=self.C,
@ -171,11 +171,11 @@ class FlowUpsamplerNet(nn.Module):
return H, W
def get_flow_permutation(self, flow_permutation, opt):
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv')
return flow_permutation
def get_affineInCh(self, opt_get):
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
affineInCh = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
affineInCh = (len(affineInCh) + 1) * 64
return affineInCh
@ -204,7 +204,7 @@ class FlowUpsamplerNet(nn.Module):
level_conditionals = {}
bypasses = {}
L = opt_get(self.opt, ['network_G', 'flow', 'L'])
L = opt_get(self.opt, ['networks', 'generator','flow', 'L'])
for level in range(1, L + 1):
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
@ -255,7 +255,7 @@ class FlowUpsamplerNet(nn.Module):
# debug.imwrite("fl_fea", fl_fea)
bypasses = {}
level_conditionals = {}
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
if not opt_get(self.opt, ['networks', 'generator','flow', 'levelConditional', 'conditional']) == True:
for level in range(self.L + 1):
level_conditionals[level] = rrdbResults[self.levelToName[level]]

View File

@ -3,7 +3,7 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from models.modules import thops
from models.archs.srflow_orig import thops
class InvertibleConv1x1(nn.Module):

View File

@ -2,70 +2,148 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.modules.module_util as mutil
import models.archs.srflow_orig.module_util as mutil
from models.archs.arch_util import default_init_weights, ConvGnSilu
from utils.util import opt_get
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
mid_channels (int): Channel number of intermediate features.
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels=64, growth_channels=32):
super(ResidualDenseBlock, self).__init__()
for i in range(5):
out_channels = mid_channels if i == 4 else growth_channels
self.add_module(
f'conv{i+1}',
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
1, 1))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for i in range(5):
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
# initialization
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Emperically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
"""Residual in Residual Dense Block.
def __init__(self, nf, gc=32):
Used in RRDB-Net in ESRGAN.
Args:
mid_channels (int): Channel number of intermediate features.
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Emperically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
class RRDBWithBypass(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
mid_channels (int): Channel number of intermediate features.
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32):
super(RRDBWithBypass, self).__init__()
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
nn.Sigmoid())
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
bypass = self.bypass(torch.cat([x, out], dim=1))
self.bypass_map = bypass.detach().clone()
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 * bypass + x
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
self.opt = opt
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass'])
if bypass:
RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc)
else:
RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.body = mutil.make_layer(RRDB_block_f, nb)
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 8:
self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 16:
self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_up4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 32:
self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_up5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@ -73,23 +151,23 @@ class RRDBNet(nn.Module):
def forward(self, x, get_steps=False):
fea = self.conv_first(x)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
block_results = {}
for idx, m in enumerate(self.RRDB_trunk.children()):
for idx, m in enumerate(self.body.children()):
fea = m(fea)
for b in block_idxs:
if b == idx:
block_results["block_{}".format(idx)] = fea
trunk = self.trunk_conv(fea)
trunk = self.conv_body(fea)
last_lr_fea = fea + trunk
fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up2)
fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up4)
fea_up8 = None
@ -97,16 +175,16 @@ class RRDBNet(nn.Module):
fea_up32 = None
if self.scale >= 8:
fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up8)
if self.scale >= 16:
fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up16)
if self.scale >= 32:
fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up32)
out = self.conv_last(self.lrelu(self.HRconv(fea)))
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
results = {'last_lr_fea': last_lr_fea,
'fea_up1': last_lr_fea,
@ -117,10 +195,10 @@ class RRDBNet(nn.Module):
'fea_up32': fea_up32,
'out': out}
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
if fea_up0_en:
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
if fea_upn1_en:
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)

View File

@ -4,10 +4,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.modules.RRDBNet_arch import RRDBNet
from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
import models.modules.thops as thops
import models.modules.flow as flow
from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
import models.archs.srflow_orig.thops as thops
import models.archs.srflow_orig.flow as flow
from utils.util import opt_get
@ -19,27 +19,40 @@ class SRFlowNet(nn.Module):
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
None else opt_get(opt, ['datasets', 'train', 'quant'])
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
if 'pretrain_rrdb' in opt['networks']['generator'].keys():
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
hidden_channels = hidden_channels or 64
self.RRDB_training = True # Default is true
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
set_RRDB_to_train = False
if set_RRDB_to_train:
self.set_rrdb_training(True)
train_RRDB_delay = opt_get(self.opt, ['networks', 'generator','train_RRDB_delay'])
self.RRDB_training = False
self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
self.i = 0
def set_rrdb_training(self, trainable):
if self.RRDB_training != trainable:
for p in self.RRDB.parameters():
p.requires_grad = trainable
self.RRDB_training = trainable
return True
return False
def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
if seed: torch.manual_seed(seed)
if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
C = self.flowUpsamplerNet.C
H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
size = (batch_size, C, H, W)
if heat == 0:
z = torch.zeros(size)
else:
z = torch.normal(mean=0, std=heat, size=size)
else:
L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
fac = 2 ** (L - 3)
z_size = int(self.lr_size // (2 ** (L - 3)))
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
return z.to(device)
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
lr_enc=None,
@ -48,14 +61,10 @@ class SRFlowNet(nn.Module):
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
y_onehot=y_label)
else:
# assert lr.shape[0] == 1
assert lr.shape[1] == 3
# assert lr.shape[2] == 20
# assert lr.shape[3] == 20
# assert z.shape[0] == 1
# assert z.shape[1] == 3 * 8 * 8
# assert z.shape[2] == 20
# assert z.shape[3] == 20
if z is None:
# Synthesize it.
z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
if reverse_with_grad:
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
add_gt_noise=add_gt_noise)
@ -66,7 +75,11 @@ class SRFlowNet(nn.Module):
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
if lr_enc is None:
lr_enc = self.rrdbPreprocessing(lr)
if self.RRDB_training:
lr_enc = self.rrdbPreprocessing(lr)
else:
with torch.no_grad():
lr_enc = self.rrdbPreprocessing(lr)
logdet = torch.zeros_like(gt[:, 0, 0, 0])
pixels = thops.pixels(gt)
@ -75,7 +88,7 @@ class SRFlowNet(nn.Module):
if add_gt_noise:
# Setup
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
noiseQuant = opt_get(self.opt, ['networks', 'generator','flow', 'augmentation', 'noiseQuant'], True)
if noiseQuant:
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
logdet = logdet + float(-np.log(self.quant) * pixels)
@ -101,11 +114,11 @@ class SRFlowNet(nn.Module):
def rrdbPreprocessing(self, lr):
rrdbResults = self.RRDB(lr, get_steps=True)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
if len(block_idxs) > 0:
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
if opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'concat']) or False:
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
if 'fea_up0' in rrdbResults.keys():
keys.append('fea_up0')
@ -134,7 +147,11 @@ class SRFlowNet(nn.Module):
logdet = logdet - float(-np.log(self.quant) * pixels)
if lr_enc is None:
lr_enc = self.rrdbPreprocessing(lr)
if self.RRDB_training:
lr_enc = self.rrdbPreprocessing(lr)
else:
with torch.no_grad():
lr_enc = self.rrdbPreprocessing(lr)
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
logdet=logdet)

View File

@ -1,9 +1,9 @@
import torch
from torch import nn as nn
from models.modules import thops
from models.modules.FlowStep import FlowStep
from models.modules.flow import Conv2dZeros, GaussianDiag
from models.archs.srflow_orig import thops
from models.archs.srflow_orig.FlowStep import FlowStep
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get

View File

@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.modules.FlowActNorms import ActNorm2d
from models.archs.srflow_orig.FlowActNorms import ActNorm2d
from . import thops

View File

@ -105,12 +105,25 @@ class BaseModel():
if 'state_dict' in load_net:
load_net = load_net['state_dict']
is_srflow = False
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
if k.startswith('generator'): # Hack to fix ESRGAN pretrained model.
load_net_clean[k[10:]] = v
if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names.
is_srflow = True
fixed_key = k.replace('RRDB_trunk', 'body')
if '.RDB' in fixed_key:
fixed_key = fixed_key.replace('.RDB', '.rdb')
elif '.upconv' in fixed_key:
fixed_key = fixed_key.replace('.upconv', '.conv_up')
elif '.trunk_conv' in fixed_key:
fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
elif '.HRconv' in fixed_key:
fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
load_net_clean[fixed_key] = v
else:
load_net_clean[k] = v
network.load_state_dict(load_net_clean, strict=strict)

View File

@ -28,11 +28,7 @@ from models.archs.teco_resgen import TecoGen
logger = logging.getLogger('base')
# Generator
def define_G(opt, net_key='network_G', scale=None):
if net_key is not None:
opt_net = opt[net_key]
else:
opt_net = opt
def define_G(opt, opt_net, scale=None):
if scale is None:
scale = opt['scale']
which_model = opt_net['which_model_G']
@ -141,6 +137,17 @@ def define_G(opt, net_key='network_G', scale=None):
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'], structure_input=is_structured,
attn_layers=attn)
elif which_model == 'srflow':
from models.archs.srflow import SRFlow_arch
netG = SRFlow_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'],
quant=opt_net['quant'], flow_block_maps=opt_net['rrdb_block_maps'],
noise_quant=opt_net['noise_quant'], hidden_channels=opt_net['nf'],
K=opt_net['K'], L=opt_net['L'], train_rrdb_at_step=opt_net['rrdb_train_step'],
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
elif which_model == 'srflow_orig':
from models.archs.srflow_orig import SRFlowNet_arch
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
K=opt_net['K'], opt=opt)
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -157,6 +157,8 @@ class AddNoiseInjector(Injector):
scale = state[self.opt['scale']]
else:
scale = self.opt['scale']
if scale is None:
scale = 1
ref = state[self.opt['in']]
if self.mode == 'normal':

View File

@ -139,7 +139,8 @@ class ConfigurableStep(Module):
continue
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
continue
injected = inj(local_state)
local_state.update(injected)

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()