"Clean up" SRFlow

This commit is contained in:
James Betker 2020-11-19 21:42:24 -07:00
parent d7877d0a36
commit 1e0d7be3ce
7 changed files with 124 additions and 163 deletions

View File

@ -3,11 +3,10 @@ from torch import nn as nn
from models.archs.srflow import thops from models.archs.srflow import thops
from models.archs.srflow.flow import Conv2d, Conv2dZeros from models.archs.srflow.flow import Conv2d, Conv2dZeros
from utils.util import opt_get
class CondAffineSeparatedAndCond(nn.Module): class CondAffineSeparatedAndCond(nn.Module):
def __init__(self, in_channels, opt): def __init__(self, in_channels, hidden_channels=64, affine_eps=.00001):
super().__init__() super().__init__()
self.need_features = True self.need_features = True
self.in_channels = in_channels self.in_channels = in_channels
@ -15,10 +14,8 @@ class CondAffineSeparatedAndCond(nn.Module):
self.kernel_hidden = 1 self.kernel_hidden = 1
self.affine_eps = 0.0001 self.affine_eps = 0.0001
self.n_hidden_layers = 1 self.n_hidden_layers = 1
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) self.hidden_channels = hidden_channels
self.hidden_channels = 64 if hidden_channels is None else hidden_channels self.affine_eps = affine_eps
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
self.channels_for_nn = self.in_channels // 2 self.channels_for_nn = self.in_channels // 2
self.channels_for_co = self.in_channels - self.channels_for_nn self.channels_for_co = self.in_channels - self.channels_for_nn

View File

@ -1,10 +1,7 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
import models.archs.srflow from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation, FlowActNorms, Permutations
import models.archs.srflow.Permutations
from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation
from utils.util import opt_get
def getConditional(rrdbResults, position): def getConditional(rrdbResults, position):
@ -28,7 +25,7 @@ class FlowStep(nn.Module):
def __init__(self, in_channels, hidden_channels, def __init__(self, in_channels, hidden_channels,
actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive", actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None, LU_decomposed=False, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
position=None): position=None):
# check configures # check configures
assert flow_permutation in FlowStep.FlowPermutation, \ assert flow_permutation in FlowStep.FlowPermutation, \
@ -47,17 +44,16 @@ class FlowStep(nn.Module):
self.acOpt = acOpt self.acOpt = acOpt
# 1. actnorm # 1. actnorm
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) self.actnorm = FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
# 2. permute # 2. permute
if flow_permutation == "invconv": if flow_permutation == "invconv":
self.invconv = models.modules.Permutations.InvertibleConv1x1( self.invconv = Permutations.InvertibleConv1x1(
in_channels, LU_decomposed=LU_decomposed) in_channels, LU_decomposed=LU_decomposed)
# 3. coupling # 3. coupling
if flow_coupling == "CondAffineSeparatedAndCond": if flow_coupling == "CondAffineSeparatedAndCond":
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, self.affine = FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels)
opt=opt)
elif flow_coupling == "noCoupling": elif flow_coupling == "noCoupling":
pass pass
else: else:

View File

@ -3,34 +3,39 @@ import torch
from torch import nn as nn from torch import nn as nn
import models.archs.srflow.Split import models.archs.srflow.Split
from models.archs.srflow import flow, thops from models.archs.srflow import flow, thops, Split
from models.archs.srflow.Split import Split2d from models.archs.srflow.Split import Split2d
from models.archs.srflow.glow_arch import f_conv2d_bias from models.archs.srflow.glow_arch import f_conv2d_bias
from models.archs.srflow.FlowStep import FlowStep from models.archs.srflow.FlowStep import FlowStep
from utils.util import opt_get from utils.util import opt_get
import torchvision
class FlowUpsamplerNet(nn.Module): class FlowUpsamplerNet(nn.Module):
def __init__(self, image_shape, hidden_channels, K, L=None, def __init__(self, image_shape, hidden_channels, scale,
rrdb_blocks,
actnorm_scale=1.0, actnorm_scale=1.0,
flow_permutation=None, flow_permutation='invconv',
flow_coupling="affine", flow_coupling="affine",
LU_decomposed=False, opt=None): LU_decomposed=False, K=16, L=3,
norm_opt=None,
n_bypass_channels=None):
super().__init__() super().__init__()
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.output_shapes = [] self.output_shapes = []
self.L = opt_get(opt, ['network_G', 'flow', 'L']) self.L = L
self.K = opt_get(opt, ['network_G', 'flow', 'K']) self.K = K
self.scale=scale
if isinstance(self.K, int): if isinstance(self.K, int):
self.K = [K for K in [K, ] * (self.L + 1)] self.K = [K for K in [K, ] * (self.L + 1)]
self.opt = opt
H, W, self.C = image_shape H, W, self.C = image_shape
self.image_shape = image_shape
self.check_image_shape() self.check_image_shape()
if opt['scale'] == 16: if scale == 16:
self.levelToName = { self.levelToName = {
0: 'fea_up16', 0: 'fea_up16',
1: 'fea_up8', 1: 'fea_up8',
@ -39,7 +44,7 @@ class FlowUpsamplerNet(nn.Module):
4: 'fea_up1', 4: 'fea_up1',
} }
if opt['scale'] == 8: if scale == 8:
self.levelToName = { self.levelToName = {
0: 'fea_up8', 0: 'fea_up8',
1: 'fea_up4', 1: 'fea_up4',
@ -48,7 +53,7 @@ class FlowUpsamplerNet(nn.Module):
4: 'fea_up0' 4: 'fea_up0'
} }
elif opt['scale'] == 4: elif scale == 4:
self.levelToName = { self.levelToName = {
0: 'fea_up4', 0: 'fea_up4',
1: 'fea_up2', 1: 'fea_up2',
@ -57,14 +62,10 @@ class FlowUpsamplerNet(nn.Module):
4: 'fea_up-1' 4: 'fea_up-1'
} }
affineInCh = self.get_affineInCh(opt_get) affineInCh = self.get_affineInCh(rrdb_blocks)
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
conditional_channels = {} conditional_channels = {}
n_rrdb = self.get_n_rrdb_channels(opt, opt_get) n_rrdb = self.get_n_rrdb_channels(rrdb_blocks)
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
conditional_channels[0] = n_rrdb conditional_channels[0] = n_rrdb
for level in range(1, self.L + 1): for level in range(1, self.L + 1):
# Level 1 gets conditionals from 2, 3, 4 => L - level # Level 1 gets conditionals from 2, 3, 4 => L - level
@ -80,37 +81,29 @@ class FlowUpsamplerNet(nn.Module):
H, W = self.arch_squeeze(H, W) H, W = self.arch_squeeze(H, W)
# 2. K FlowStep # 2. K FlowStep
self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt) self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels)
self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
flow_permutation, flow_permutation,
hidden_channels, normOpt, opt, opt_get, hidden_channels, norm_opt,
n_conditinal_channels=conditional_channels[level]) n_conditional_channels=conditional_channels[level])
# Split # Split
self.arch_split(H, W, level, self.L, opt, opt_get) self.arch_split(H, W, level, self.L)
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
else:
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
self.H = H self.H = H
self.W = W self.W = W
self.scaleH = 160 / H
self.scaleW = 160 / W
def get_n_rrdb_channels(self, opt, opt_get): def get_n_rrdb_channels(self, blocks):
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
return n_rrdb return n_rrdb
def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation, def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None): hidden_channels, normOpt, n_conditional_channels=None, condAff=None):
condAff = self.get_condAffSetting(opt, opt_get)
if condAff is not None: if condAff is not None:
condAff['in_channels_rrdb'] = n_conditinal_channels condAff['in_channels_rrdb'] = n_conditional_channels
for k in range(K): for k in range(K):
position_name = get_position_name(H, self.opt['scale']) position_name = self.get_position_name(H, self.scale)
if normOpt: normOpt['position'] = position_name if normOpt: normOpt['position'] = position_name
self.layers.append( self.layers.append(
@ -121,48 +114,37 @@ class FlowUpsamplerNet(nn.Module):
flow_coupling=flow_coupling, flow_coupling=flow_coupling,
acOpt=condAff, acOpt=condAff,
position=position_name, position=position_name,
LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt)) LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt))
self.output_shapes.append( self.output_shapes.append(
[-1, self.C, H, W]) [-1, self.C, H, W])
def get_condAffSetting(self, opt, opt_get): def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5, split_conditional=False, cond_channels=None, split_type='Split2d'):
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
return condAff
def arch_split(self, H, W, L, levels, opt, opt_get):
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
correction = 0 if correct_splits else 1 correction = 0 if correct_splits else 1
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction: if split_flow and L < levels - correction:
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0 logs_eps = logs_eps
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5 consume_ratio = consume_ratio
position_name = get_position_name(H, self.opt['scale']) position_name = self.get_position_name(H, self.scale)
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None position = position_name if split_conditional else None
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
cond_channels = 0 if cond_channels is None else cond_channels cond_channels = 0 if cond_channels is None else cond_channels
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d') if split_type == 'Split2d':
split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
if t == 'Split2d': cond_channels=cond_channels, consume_ratio=consume_ratio)
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
self.layers.append(split) self.layers.append(split)
self.output_shapes.append([-1, split.num_channels_pass, H, W]) self.output_shapes.append([-1, split.num_channels_pass, H, W])
self.C = split.num_channels_pass self.C = split.num_channels_pass
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt): def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2):
if 'additionalFlowNoAffine' in opt['network_G']['flow']: for _ in range(additionalFlowNoAffine):
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine']) self.layers.append(
for _ in range(n_additionalFlowNoAffine): FlowStep(in_channels=self.C,
self.layers.append( hidden_channels=hidden_channels,
FlowStep(in_channels=self.C, actnorm_scale=actnorm_scale,
hidden_channels=hidden_channels, flow_permutation='invconv',
actnorm_scale=actnorm_scale, flow_coupling='noCoupling',
flow_permutation='invconv', LU_decomposed=LU_decomposed))
flow_coupling='noCoupling', self.output_shapes.append(
LU_decomposed=LU_decomposed, opt=opt)) [-1, self.C, H, W])
self.output_shapes.append(
[-1, self.C, H, W])
def arch_squeeze(self, H, W): def arch_squeeze(self, H, W):
self.C, H, W = self.C * 4, H // 2, W // 2 self.C, H, W = self.C * 4, H // 2, W // 2
@ -170,13 +152,8 @@ class FlowUpsamplerNet(nn.Module):
self.output_shapes.append([-1, self.C, H, W]) self.output_shapes.append([-1, self.C, H, W])
return H, W return H, W
def get_flow_permutation(self, flow_permutation, opt): def get_affineInCh(self, rrdb_blocks):
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv') affineInCh = (len(rrdb_blocks) + 1) * 64
return flow_permutation
def get_affineInCh(self, opt_get):
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
affineInCh = (len(affineInCh) + 1) * 64
return affineInCh return affineInCh
def check_image_shape(self): def check_image_shape(self):
@ -204,14 +181,14 @@ class FlowUpsamplerNet(nn.Module):
level_conditionals = {} level_conditionals = {}
bypasses = {} bypasses = {}
L = opt_get(self.opt, ['network_G', 'flow', 'L']) L = self.L
for level in range(1, L + 1): for level in range(1, L + 1):
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
for layer, shape in zip(self.layers, self.output_shapes): for layer, shape in zip(self.layers, self.output_shapes):
size = shape[2] size = shape[2]
level = int(np.log(160 / size) / np.log(2)) level = int(np.log(self.image_shape[0] / size) / np.log(2))
if level > 0 and level not in level_conditionals.keys(): if level > 0 and level not in level_conditionals.keys():
level_conditionals[level] = rrdbResults[self.levelToName[level]] level_conditionals[level] = rrdbResults[self.levelToName[level]]
@ -255,15 +232,12 @@ class FlowUpsamplerNet(nn.Module):
# debug.imwrite("fl_fea", fl_fea) # debug.imwrite("fl_fea", fl_fea)
bypasses = {} bypasses = {}
level_conditionals = {} level_conditionals = {}
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True: for level in range(self.L + 1):
for level in range(self.L + 1): level_conditionals[level] = rrdbResults[self.levelToName[level]]
level_conditionals[level] = rrdbResults[self.levelToName[level]]
for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
size = shape[2] size = shape[2]
level = int(np.log(160 / size) / np.log(2)) level = int(np.log(self.H / size) / np.log(2))
# size = fl_fea.shape[2]
# level = int(np.log(160 / size) / np.log(2))
if isinstance(layer, Split2d): if isinstance(layer, Split2d):
fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer, fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
@ -287,7 +261,7 @@ class FlowUpsamplerNet(nn.Module):
return fl_fea, logdet return fl_fea, logdet
def get_position_name(H, scale): def get_position_name(self, H, scale):
downscale_factor = 160 // H downscale_factor = self.image_shape[0] // H
position_name = 'fea_up{}'.format(scale / downscale_factor) position_name = 'fea_up{}'.format(scale / downscale_factor)
return position_name return position_name

View File

@ -3,7 +3,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from models.modules import thops from models.archs.srflow import thops
class InvertibleConv1x1(nn.Module): class InvertibleConv1x1(nn.Module):

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import models.archs.srflow.module_util as mutil import models.archs.srflow.module_util as mutil
from utils.util import opt_get from utils.util import opt_get, checkpoint
class ResidualDenseBlock_5C(nn.Module): class ResidualDenseBlock_5C(nn.Module):
@ -46,11 +46,14 @@ class RRDB(nn.Module):
class RRDBNet(nn.Module): class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, block_outputs=[], fea_up0=True,
self.opt = opt fea_up1=False):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.scale = scale self.scale = scale
self.block_outputs = block_outputs
self.fea_up0 = fea_up0
self.fea_up1 = fea_up1
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
@ -73,11 +76,11 @@ class RRDBNet(nn.Module):
def forward(self, x, get_steps=False): def forward(self, x, get_steps=False):
fea = self.conv_first(x) fea = self.conv_first(x)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] block_idxs = self.block_outputs or []
block_results = {} block_results = {}
for idx, m in enumerate(self.RRDB_trunk.children()): for idx, m in enumerate(self.RRDB_trunk.children()):
fea = m(fea) fea = checkpoint(m, fea)
for b in block_idxs: for b in block_idxs:
if b == idx: if b == idx:
block_results["block_{}".format(idx)] = fea block_results["block_{}".format(idx)] = fea
@ -117,11 +120,9 @@ class RRDBNet(nn.Module):
'fea_up32': fea_up32, 'fea_up32': fea_up32,
'out': out} 'out': out}
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False if self.fea_up0:
if fea_up0_en:
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False if self.fea_up1:
if fea_upn1_en:
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
if get_steps: if get_steps:

View File

@ -7,7 +7,7 @@ from models.archs.srflow.flow import Conv2dZeros, GaussianDiag
class Split2d(nn.Module): class Split2d(nn.Module):
def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5):
super().__init__() super().__init__()
self.num_channels_consume = int(round(num_channels * consume_ratio)) self.num_channels_consume = int(round(num_channels * consume_ratio))
@ -17,7 +17,6 @@ class Split2d(nn.Module):
out_channels=self.num_channels_consume * 2) out_channels=self.num_channels_consume * 2)
self.logs_eps = logs_eps self.logs_eps = logs_eps
self.position = position self.position = position
self.opt = opt
def split2d_prior(self, z, ft): def split2d_prior(self, z, ft):
if ft is not None: if ft is not None:

View File

@ -4,57 +4,41 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.archs.srflow.thops as thops import models.archs.srflow.thops as thops
import models.archs.srflow.flow as flow import models.archs.srflow.flow as flow
from utils.util import opt_get from models.archs.srflow.RRDBNet_arch import RRDBNet
class SRFlowNet(nn.Module): class SRFlowNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None): def __init__(self, in_nc, out_nc, nf, nb, quant, flow_block_maps, noise_quant,
hidden_channels=64, gc=32, scale=4, K=16, L=3, train_rrdb_at_step=0,
hr_img_shape=(128,128,3), coupling='CondAffineSeparatedAndCond'):
super(SRFlowNet, self).__init__() super(SRFlowNet, self).__init__()
self.opt = opt self.scale = scale
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ self.noise_quant = noise_quant
None else opt_get(opt, ['datasets', 'train', 'quant']) self.quant = quant
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) self.flow_block_maps = flow_block_maps
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, flow_block_maps)
hidden_channels = hidden_channels or 64 self.train_rrdb_step = train_rrdb_at_step
self.RRDB_training = True # Default is true self.RRDB_training = True
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) self.flowUpsamplerNet = FlowUpsamplerNet(image_shape=hr_img_shape,
set_RRDB_to_train = False hidden_channels=hidden_channels,
if set_RRDB_to_train: scale=scale, rrdb_blocks=flow_block_maps,
self.set_rrdb_training(True) K=K, L=L, flow_coupling=coupling)
self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
self.i = 0 self.i = 0
def set_rrdb_training(self, trainable): def forward(self, gt=None, lr=None, reverse=False, z=None, eps_std=None, epses=None, reverse_with_grad=False,
if self.RRDB_training != trainable:
for p in self.RRDB.parameters():
p.requires_grad = trainable
self.RRDB_training = trainable
return True
return False
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
lr_enc=None, lr_enc=None,
add_gt_noise=False, step=None, y_label=None): add_gt_noise=False, step=None, y_label=None):
if not reverse: if not reverse:
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
y_onehot=y_label) y_onehot=y_label)
else: else:
# assert lr.shape[0] == 1
assert lr.shape[1] == 3 assert lr.shape[1] == 3
# assert lr.shape[2] == 20
# assert lr.shape[3] == 20
# assert z.shape[0] == 1
# assert z.shape[1] == 3 * 8 * 8
# assert z.shape[2] == 20
# assert z.shape[3] == 20
if reverse_with_grad: if reverse_with_grad:
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
add_gt_noise=add_gt_noise) add_gt_noise=add_gt_noise)
@ -74,8 +58,7 @@ class SRFlowNet(nn.Module):
if add_gt_noise: if add_gt_noise:
# Setup # Setup
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True) if self.noise_quant:
if noiseQuant:
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
logdet = logdet + float(-np.log(self.quant) * pixels) logdet = logdet + float(-np.log(self.quant) * pixels)
@ -100,24 +83,23 @@ class SRFlowNet(nn.Module):
def rrdbPreprocessing(self, lr): def rrdbPreprocessing(self, lr):
rrdbResults = self.RRDB(lr, get_steps=True) rrdbResults = self.RRDB(lr, get_steps=True)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] block_idxs = self.flow_block_maps
if len(block_idxs) > 0: if len(block_idxs) > 0:
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False: keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] if 'fea_up0' in rrdbResults.keys():
if 'fea_up0' in rrdbResults.keys(): keys.append('fea_up0')
keys.append('fea_up0') if 'fea_up-1' in rrdbResults.keys():
if 'fea_up-1' in rrdbResults.keys(): keys.append('fea_up-1')
keys.append('fea_up-1') if self.scale >= 8:
if self.opt['scale'] >= 8: keys.append('fea_up8')
keys.append('fea_up8') if self.scale == 16:
if self.opt['scale'] == 16: keys.append('fea_up16')
keys.append('fea_up16') for k in keys:
for k in keys: h = rrdbResults[k].shape[2]
h = rrdbResults[k].shape[2] w = rrdbResults[k].shape[3]
w = rrdbResults[k].shape[3] rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
return rrdbResults return rrdbResults
def get_score(self, disc_loss_sigma, z): def get_score(self, disc_loss_sigma, z):
@ -127,7 +109,7 @@ class SRFlowNet(nn.Module):
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
logdet = torch.zeros_like(lr[:, 0, 0, 0]) logdet = torch.zeros_like(lr[:, 0, 0, 0])
pixels = thops.pixels(lr) * self.opt['scale'] ** 2 pixels = thops.pixels(lr) * self.scale ** 2
if add_gt_noise: if add_gt_noise:
logdet = logdet - float(-np.log(self.quant) * pixels) logdet = logdet - float(-np.log(self.quant) * pixels)
@ -138,4 +120,16 @@ class SRFlowNet(nn.Module):
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
logdet=logdet) logdet=logdet)
return x, logdet return x, logdet
def set_rrdb_training(self, trainable):
if self.RRDB_training != trainable:
for p in self.RRDB.parameters():
if not trainable:
p.DO_NOT_TRAIN = True
elif hasattr(p, "DO_NOT_TRAIN"):
del p.DO_NOT_TRAIN
self.RRDB_training = trainable
def update_for_step(self, step, experiments_path='.'):
self.set_rrdb_training(step > self.train_rrdb_step)