srflow_orig integration

This commit is contained in:
James Betker 2020-11-19 23:47:24 -07:00
parent f80acfcab6
commit 5ccdbcefe3
16 changed files with 239 additions and 122 deletions

View File

@ -58,7 +58,7 @@ class ExtensibleTrainer(BaseModel):
new_net = None new_net = None
if net['type'] == 'generator': if net['type'] == 'generator':
if new_net is None: if new_net is None:
new_net = networks.define_G(net, None, opt['scale']).to(self.device) new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
self.netsG[name] = new_net self.netsG[name] = new_net
elif net['type'] == 'discriminator': elif net['type'] == 'discriminator':
if new_net is None: if new_net is None:

View File

@ -1,7 +1,7 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from models.modules import thops from models.archs.srflow_orig import thops
class _ActNorm(nn.Module): class _ActNorm(nn.Module):

View File

@ -1,8 +1,8 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from models.modules import thops from models.archs.srflow_orig import thops
from models.modules.flow import Conv2d, Conv2dZeros from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
from utils.util import opt_get from utils.util import opt_get
@ -15,10 +15,10 @@ class CondAffineSeparatedAndCond(nn.Module):
self.kernel_hidden = 1 self.kernel_hidden = 1
self.affine_eps = 0.0001 self.affine_eps = 0.0001
self.n_hidden_layers = 1 self.n_hidden_layers = 1
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
self.hidden_channels = 64 if hidden_channels is None else hidden_channels self.hidden_channels = 64 if hidden_channels is None else hidden_channels
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) self.affine_eps = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
self.channels_for_nn = self.in_channels // 2 self.channels_for_nn = self.in_channels // 2
self.channels_for_co = self.in_channels - self.channels_for_nn self.channels_for_co = self.in_channels - self.channels_for_nn

View File

@ -1,9 +1,8 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
import models.modules import models.archs.srflow_orig.Permutations
import models.modules.Permutations from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
from models.modules import flow, thops, FlowAffineCouplingsAblation
from utils.util import opt_get from utils.util import opt_get
@ -47,16 +46,16 @@ class FlowStep(nn.Module):
self.acOpt = acOpt self.acOpt = acOpt
# 1. actnorm # 1. actnorm
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
# 2. permute # 2. permute
if flow_permutation == "invconv": if flow_permutation == "invconv":
self.invconv = models.modules.Permutations.InvertibleConv1x1( self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
in_channels, LU_decomposed=LU_decomposed) in_channels, LU_decomposed=LU_decomposed)
# 3. coupling # 3. coupling
if flow_coupling == "CondAffineSeparatedAndCond": if flow_coupling == "CondAffineSeparatedAndCond":
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
opt=opt) opt=opt)
elif flow_coupling == "noCoupling": elif flow_coupling == "noCoupling":
pass pass

View File

@ -2,11 +2,11 @@ import numpy as np
import torch import torch
from torch import nn as nn from torch import nn as nn
import models.modules.Split import models.archs.srflow_orig.Split
from models.modules import flow, thops from models.archs.srflow_orig import flow, thops
from models.modules.Split import Split2d from models.archs.srflow_orig.Split import Split2d
from models.modules.glow_arch import f_conv2d_bias from models.archs.srflow_orig.glow_arch import f_conv2d_bias
from models.modules.FlowStep import FlowStep from models.archs.srflow_orig.FlowStep import FlowStep
from utils.util import opt_get from utils.util import opt_get
@ -21,8 +21,8 @@ class FlowUpsamplerNet(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.output_shapes = [] self.output_shapes = []
self.L = opt_get(opt, ['network_G', 'flow', 'L']) self.L = opt_get(opt, ['networks', 'generator','flow', 'L'])
self.K = opt_get(opt, ['network_G', 'flow', 'K']) self.K = opt_get(opt, ['networks', 'generator','flow', 'K'])
if isinstance(self.K, int): if isinstance(self.K, int):
self.K = [K for K in [K, ] * (self.L + 1)] self.K = [K for K in [K, ] * (self.L + 1)]
@ -60,11 +60,11 @@ class FlowUpsamplerNet(nn.Module):
affineInCh = self.get_affineInCh(opt_get) affineInCh = self.get_affineInCh(opt_get)
flow_permutation = self.get_flow_permutation(flow_permutation, opt) flow_permutation = self.get_flow_permutation(flow_permutation, opt)
normOpt = opt_get(opt, ['network_G', 'flow', 'norm']) normOpt = opt_get(opt, ['networks', 'generator','flow', 'norm'])
conditional_channels = {} conditional_channels = {}
n_rrdb = self.get_n_rrdb_channels(opt, opt_get) n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels']) n_bypass_channels = opt_get(opt, ['networks', 'generator','flow', 'levelConditional', 'n_channels'])
conditional_channels[0] = n_rrdb conditional_channels[0] = n_rrdb
for level in range(1, self.L + 1): for level in range(1, self.L + 1):
# Level 1 gets conditionals from 2, 3, 4 => L - level # Level 1 gets conditionals from 2, 3, 4 => L - level
@ -88,7 +88,7 @@ class FlowUpsamplerNet(nn.Module):
# Split # Split
self.arch_split(H, W, level, self.L, opt, opt_get) self.arch_split(H, W, level, self.L, opt, opt_get)
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']): if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']):
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
else: else:
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64) self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
@ -99,7 +99,7 @@ class FlowUpsamplerNet(nn.Module):
self.scaleW = 160 / W self.scaleW = 160 / W
def get_n_rrdb_channels(self, opt, opt_get): def get_n_rrdb_channels(self, opt, opt_get):
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks'])
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
return n_rrdb return n_rrdb
@ -126,33 +126,33 @@ class FlowUpsamplerNet(nn.Module):
[-1, self.C, H, W]) [-1, self.C, H, W])
def get_condAffSetting(self, opt, opt_get): def get_condAffSetting(self, opt, opt_get):
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None condAff = opt_get(opt, ['networks', 'generator','flow', 'condAff']) or None
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff condAff = opt_get(opt, ['networks', 'generator','flow', 'condFtAffine']) or condAff
return condAff return condAff
def arch_split(self, H, W, L, levels, opt, opt_get): def arch_split(self, H, W, L, levels, opt, opt_get):
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False) correct_splits = opt_get(opt, ['networks', 'generator','flow', 'split', 'correct_splits'], False)
correction = 0 if correct_splits else 1 correction = 0 if correct_splits else 1
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction: if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction:
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0 logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5 consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5
position_name = get_position_name(H, self.opt['scale']) position_name = get_position_name(H, self.opt['scale'])
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels']) cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels'])
cond_channels = 0 if cond_channels is None else cond_channels cond_channels = 0 if cond_channels is None else cond_channels
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d') t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d')
if t == 'Split2d': if t == 'Split2d':
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt) cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
self.layers.append(split) self.layers.append(split)
self.output_shapes.append([-1, split.num_channels_pass, H, W]) self.output_shapes.append([-1, split.num_channels_pass, H, W])
self.C = split.num_channels_pass self.C = split.num_channels_pass
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt): def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
if 'additionalFlowNoAffine' in opt['network_G']['flow']: if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']:
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine']) n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine'])
for _ in range(n_additionalFlowNoAffine): for _ in range(n_additionalFlowNoAffine):
self.layers.append( self.layers.append(
FlowStep(in_channels=self.C, FlowStep(in_channels=self.C,
@ -171,11 +171,11 @@ class FlowUpsamplerNet(nn.Module):
return H, W return H, W
def get_flow_permutation(self, flow_permutation, opt): def get_flow_permutation(self, flow_permutation, opt):
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv') flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv')
return flow_permutation return flow_permutation
def get_affineInCh(self, opt_get): def get_affineInCh(self, opt_get):
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] affineInCh = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
affineInCh = (len(affineInCh) + 1) * 64 affineInCh = (len(affineInCh) + 1) * 64
return affineInCh return affineInCh
@ -204,7 +204,7 @@ class FlowUpsamplerNet(nn.Module):
level_conditionals = {} level_conditionals = {}
bypasses = {} bypasses = {}
L = opt_get(self.opt, ['network_G', 'flow', 'L']) L = opt_get(self.opt, ['networks', 'generator','flow', 'L'])
for level in range(1, L + 1): for level in range(1, L + 1):
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
@ -255,7 +255,7 @@ class FlowUpsamplerNet(nn.Module):
# debug.imwrite("fl_fea", fl_fea) # debug.imwrite("fl_fea", fl_fea)
bypasses = {} bypasses = {}
level_conditionals = {} level_conditionals = {}
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True: if not opt_get(self.opt, ['networks', 'generator','flow', 'levelConditional', 'conditional']) == True:
for level in range(self.L + 1): for level in range(self.L + 1):
level_conditionals[level] = rrdbResults[self.levelToName[level]] level_conditionals[level] = rrdbResults[self.levelToName[level]]

View File

@ -3,7 +3,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from models.modules import thops from models.archs.srflow_orig import thops
class InvertibleConv1x1(nn.Module): class InvertibleConv1x1(nn.Module):

View File

@ -2,70 +2,148 @@ import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import models.modules.module_util as mutil import models.archs.srflow_orig.module_util as mutil
from models.archs.arch_util import default_init_weights, ConvGnSilu
from utils.util import opt_get from utils.util import opt_get
class ResidualDenseBlock_5C(nn.Module): class ResidualDenseBlock(nn.Module):
def __init__(self, nf=64, gc=32, bias=True): """Residual Dense Block.
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels Used in RRDB block in ESRGAN.
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) Args:
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) mid_channels (int): Channel number of intermediate features.
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) growth_channels (int): Channels for each growth.
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) """
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def __init__(self, mid_channels=64, growth_channels=32):
super(ResidualDenseBlock, self).__init__()
for i in range(5):
out_channels = mid_channels if i == 4 else growth_channels
self.add_module(
f'conv{i+1}',
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
1, 1))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for i in range(5):
default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
# initialization
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x): def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
x1 = self.lrelu(self.conv1(x)) x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Emperically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x return x5 * 0.2 + x
class RRDB(nn.Module): class RRDB(nn.Module):
'''Residual in Residual Dense Block''' """Residual in Residual Dense Block.
def __init__(self, nf, gc=32): Used in RRDB-Net in ESRGAN.
Args:
mid_channels (int): Channel number of intermediate features.
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32):
super(RRDB, self).__init__() super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.RDB2 = ResidualDenseBlock_5C(nf, gc) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
self.RDB3 = ResidualDenseBlock_5C(nf, gc) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
def forward(self, x): def forward(self, x):
out = self.RDB1(x) """Forward function.
out = self.RDB2(out)
out = self.RDB3(out) Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Emperically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x return out * 0.2 + x
class RRDBWithBypass(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
mid_channels (int): Channel number of intermediate features.
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32):
super(RRDBWithBypass, self).__init__()
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
nn.Sigmoid())
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
bypass = self.bypass(torch.cat([x, out], dim=1))
self.bypass_map = bypass.detach().clone()
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 * bypass + x
class RRDBNet(nn.Module): class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
self.opt = opt self.opt = opt
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass'])
if bypass:
RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc)
else:
RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
self.scale = scale self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) self.body = mutil.make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling #### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 8: if self.scale >= 8:
self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 16: if self.scale >= 16:
self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 32: if self.scale >= 32:
self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_up5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@ -73,23 +151,23 @@ class RRDBNet(nn.Module):
def forward(self, x, get_steps=False): def forward(self, x, get_steps=False):
fea = self.conv_first(x) fea = self.conv_first(x)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
block_results = {} block_results = {}
for idx, m in enumerate(self.RRDB_trunk.children()): for idx, m in enumerate(self.body.children()):
fea = m(fea) fea = m(fea)
for b in block_idxs: for b in block_idxs:
if b == idx: if b == idx:
block_results["block_{}".format(idx)] = fea block_results["block_{}".format(idx)] = fea
trunk = self.trunk_conv(fea) trunk = self.conv_body(fea)
last_lr_fea = fea + trunk last_lr_fea = fea + trunk
fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up2) fea = self.lrelu(fea_up2)
fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')) fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up4) fea = self.lrelu(fea_up4)
fea_up8 = None fea_up8 = None
@ -97,16 +175,16 @@ class RRDBNet(nn.Module):
fea_up32 = None fea_up32 = None
if self.scale >= 8: if self.scale >= 8:
fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')) fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up8) fea = self.lrelu(fea_up8)
if self.scale >= 16: if self.scale >= 16:
fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest')) fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up16) fea = self.lrelu(fea_up16)
if self.scale >= 32: if self.scale >= 32:
fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest')) fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up32) fea = self.lrelu(fea_up32)
out = self.conv_last(self.lrelu(self.HRconv(fea))) out = self.conv_last(self.lrelu(self.conv_hr(fea)))
results = {'last_lr_fea': last_lr_fea, results = {'last_lr_fea': last_lr_fea,
'fea_up1': last_lr_fea, 'fea_up1': last_lr_fea,
@ -117,10 +195,10 @@ class RRDBNet(nn.Module):
'fea_up32': fea_up32, 'fea_up32': fea_up32,
'out': out} 'out': out}
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False
if fea_up0_en: if fea_up0_en:
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False
if fea_upn1_en: if fea_upn1_en:
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)

View File

@ -4,10 +4,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from models.modules.RRDBNet_arch import RRDBNet from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
from models.modules.FlowUpsamplerNet import FlowUpsamplerNet from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
import models.modules.thops as thops import models.archs.srflow_orig.thops as thops
import models.modules.flow as flow import models.archs.srflow_orig.flow as flow
from utils.util import opt_get from utils.util import opt_get
@ -19,27 +19,40 @@ class SRFlowNet(nn.Module):
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
None else opt_get(opt, ['datasets', 'train', 'quant']) None else opt_get(opt, ['datasets', 'train', 'quant'])
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) if 'pretrain_rrdb' in opt['networks']['generator'].keys():
rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels'])
hidden_channels = hidden_channels or 64 hidden_channels = hidden_channels or 64
self.RRDB_training = True # Default is true self.RRDB_training = True # Default is true
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) train_RRDB_delay = opt_get(self.opt, ['networks', 'generator','train_RRDB_delay'])
set_RRDB_to_train = False self.RRDB_training = False
if set_RRDB_to_train:
self.set_rrdb_training(True)
self.flowUpsamplerNet = \ self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K, FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt) flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
self.i = 0 self.i = 0
def set_rrdb_training(self, trainable): def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
if self.RRDB_training != trainable: if seed: torch.manual_seed(seed)
for p in self.RRDB.parameters(): if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
p.requires_grad = trainable C = self.flowUpsamplerNet.C
self.RRDB_training = trainable H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
return True W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
return False
size = (batch_size, C, H, W)
if heat == 0:
z = torch.zeros(size)
else:
z = torch.normal(mean=0, std=heat, size=size)
else:
L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
fac = 2 ** (L - 3)
z_size = int(self.lr_size // (2 ** (L - 3)))
z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
return z.to(device)
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
lr_enc=None, lr_enc=None,
@ -48,14 +61,10 @@ class SRFlowNet(nn.Module):
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
y_onehot=y_label) y_onehot=y_label)
else: else:
# assert lr.shape[0] == 1
assert lr.shape[1] == 3 assert lr.shape[1] == 3
# assert lr.shape[2] == 20 if z is None:
# assert lr.shape[3] == 20 # Synthesize it.
# assert z.shape[0] == 1 z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
# assert z.shape[1] == 3 * 8 * 8
# assert z.shape[2] == 20
# assert z.shape[3] == 20
if reverse_with_grad: if reverse_with_grad:
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
add_gt_noise=add_gt_noise) add_gt_noise=add_gt_noise)
@ -66,6 +75,10 @@ class SRFlowNet(nn.Module):
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None): def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
if lr_enc is None: if lr_enc is None:
if self.RRDB_training:
lr_enc = self.rrdbPreprocessing(lr)
else:
with torch.no_grad():
lr_enc = self.rrdbPreprocessing(lr) lr_enc = self.rrdbPreprocessing(lr)
logdet = torch.zeros_like(gt[:, 0, 0, 0]) logdet = torch.zeros_like(gt[:, 0, 0, 0])
@ -75,7 +88,7 @@ class SRFlowNet(nn.Module):
if add_gt_noise: if add_gt_noise:
# Setup # Setup
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True) noiseQuant = opt_get(self.opt, ['networks', 'generator','flow', 'augmentation', 'noiseQuant'], True)
if noiseQuant: if noiseQuant:
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
logdet = logdet + float(-np.log(self.quant) * pixels) logdet = logdet + float(-np.log(self.quant) * pixels)
@ -101,11 +114,11 @@ class SRFlowNet(nn.Module):
def rrdbPreprocessing(self, lr): def rrdbPreprocessing(self, lr):
rrdbResults = self.RRDB(lr, get_steps=True) rrdbResults = self.RRDB(lr, get_steps=True)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or []
if len(block_idxs) > 0: if len(block_idxs) > 0:
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False: if opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'concat']) or False:
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
if 'fea_up0' in rrdbResults.keys(): if 'fea_up0' in rrdbResults.keys():
keys.append('fea_up0') keys.append('fea_up0')
@ -134,6 +147,10 @@ class SRFlowNet(nn.Module):
logdet = logdet - float(-np.log(self.quant) * pixels) logdet = logdet - float(-np.log(self.quant) * pixels)
if lr_enc is None: if lr_enc is None:
if self.RRDB_training:
lr_enc = self.rrdbPreprocessing(lr)
else:
with torch.no_grad():
lr_enc = self.rrdbPreprocessing(lr) lr_enc = self.rrdbPreprocessing(lr)
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,

View File

@ -1,9 +1,9 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from models.modules import thops from models.archs.srflow_orig import thops
from models.modules.FlowStep import FlowStep from models.archs.srflow_orig.FlowStep import FlowStep
from models.modules.flow import Conv2dZeros, GaussianDiag from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get from utils.util import opt_get

View File

@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from models.modules.FlowActNorms import ActNorm2d from models.archs.srflow_orig.FlowActNorms import ActNorm2d
from . import thops from . import thops

View File

@ -105,12 +105,25 @@ class BaseModel():
if 'state_dict' in load_net: if 'state_dict' in load_net:
load_net = load_net['state_dict'] load_net = load_net['state_dict']
is_srflow = False
load_net_clean = OrderedDict() # remove unnecessary 'module.' load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items(): for k, v in load_net.items():
if k.startswith('module.'): if k.startswith('module.'):
load_net_clean[k[7:]] = v load_net_clean[k[7:]] = v
if k.startswith('generator'): # Hack to fix ESRGAN pretrained model. if k.startswith('generator'): # Hack to fix ESRGAN pretrained model.
load_net_clean[k[10:]] = v load_net_clean[k[10:]] = v
if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names.
is_srflow = True
fixed_key = k.replace('RRDB_trunk', 'body')
if '.RDB' in fixed_key:
fixed_key = fixed_key.replace('.RDB', '.rdb')
elif '.upconv' in fixed_key:
fixed_key = fixed_key.replace('.upconv', '.conv_up')
elif '.trunk_conv' in fixed_key:
fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
elif '.HRconv' in fixed_key:
fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
load_net_clean[fixed_key] = v
else: else:
load_net_clean[k] = v load_net_clean[k] = v
network.load_state_dict(load_net_clean, strict=strict) network.load_state_dict(load_net_clean, strict=strict)

View File

@ -28,11 +28,7 @@ from models.archs.teco_resgen import TecoGen
logger = logging.getLogger('base') logger = logging.getLogger('base')
# Generator # Generator
def define_G(opt, net_key='network_G', scale=None): def define_G(opt, opt_net, scale=None):
if net_key is not None:
opt_net = opt[net_key]
else:
opt_net = opt
if scale is None: if scale is None:
scale = opt['scale'] scale = opt['scale']
which_model = opt_net['which_model_G'] which_model = opt_net['which_model_G']
@ -141,6 +137,17 @@ def define_G(opt, net_key='network_G', scale=None):
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'], structure_input=is_structured, style_depth=opt_net['style_depth'], structure_input=is_structured,
attn_layers=attn) attn_layers=attn)
elif which_model == 'srflow':
from models.archs.srflow import SRFlow_arch
netG = SRFlow_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'],
quant=opt_net['quant'], flow_block_maps=opt_net['rrdb_block_maps'],
noise_quant=opt_net['noise_quant'], hidden_channels=opt_net['nf'],
K=opt_net['K'], L=opt_net['L'], train_rrdb_at_step=opt_net['rrdb_train_step'],
hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
elif which_model == 'srflow_orig':
from models.archs.srflow_orig import SRFlowNet_arch
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
K=opt_net['K'], opt=opt)
else: else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG return netG

View File

@ -157,6 +157,8 @@ class AddNoiseInjector(Injector):
scale = state[self.opt['scale']] scale = state[self.opt['scale']]
else: else:
scale = self.opt['scale'] scale = self.opt['scale']
if scale is None:
scale = 1
ref = state[self.opt['in']] ref = state[self.opt['in']]
if self.mode == 'normal': if self.mode == 'normal':

View File

@ -139,7 +139,8 @@ class ConfigurableStep(Module):
continue continue
# Don't do injections tagged with 'after' or 'before' when we are out of spec. # Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \ if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']: 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
continue continue
injected = inj(local_state) injected = inj(local_state)
local_state.update(injected) local_state.update(injected)

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()