forked from mrq/DL-Art-School

Commit 1e0d7be3ce: "Clean up" SRFlow
Parent: d7877d0a36
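The theme of the diff below: values that each module previously fished out of a nested options dict via opt_get at construction time become explicit constructor arguments with defaults, so the SRFlow pieces no longer need the opt plumbing threaded through them. For reference, the opt_get helper being dropped from most of these files is roughly the following nested lookup; this is a sketch inferred from its call sites, not the exact body in utils/util.py:

    def opt_get(opt, keys, default=None):
        # Walk a nested options dict, e.g. opt_get(opt, ['network_G', 'flow', 'K']),
        # returning `default` if any key along the path is missing.
        if opt is None:
            return default
        ret = opt
        for k in keys:
            ret = ret.get(k, None)
            if ret is None:
                return default
        return ret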
--- a/models/archs/srflow/FlowAffineCouplingsAblation.py
+++ b/models/archs/srflow/FlowAffineCouplingsAblation.py
@@ -3,11 +3,10 @@ from torch import nn as nn
 from models.archs.srflow import thops
 from models.archs.srflow.flow import Conv2d, Conv2dZeros
-from utils.util import opt_get
 
 
 class CondAffineSeparatedAndCond(nn.Module):
-    def __init__(self, in_channels, opt):
+    def __init__(self, in_channels, hidden_channels=64, affine_eps=.00001):
         super().__init__()
         self.need_features = True
         self.in_channels = in_channels
@@ -15,10 +14,8 @@ class CondAffineSeparatedAndCond(nn.Module):
         self.kernel_hidden = 1
         self.affine_eps = 0.0001
         self.n_hidden_layers = 1
-        hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
-        self.hidden_channels = 64 if hidden_channels is None else hidden_channels
-
-        self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
+        self.hidden_channels = hidden_channels
+        self.affine_eps = affine_eps
 
         self.channels_for_nn = self.in_channels // 2
         self.channels_for_co = self.in_channels - self.channels_for_nn
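With that change, the coupling layer is configured directly instead of reading opt['network_G']['flow']['CondAffineSeparatedAndCond']. A hypothetical instantiation under the new signature (argument values are illustrative):

    from models.archs.srflow.FlowAffineCouplingsAblation import CondAffineSeparatedAndCond

    # e.g. 12 channels: a 3-channel image after one squeeze (3 * 4)
    layer = CondAffineSeparatedAndCond(in_channels=12, hidden_channels=64, affine_eps=1e-5)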
--- a/models/archs/srflow/FlowStep.py
+++ b/models/archs/srflow/FlowStep.py
@@ -1,10 +1,7 @@
 import torch
 from torch import nn as nn
 
-import models.archs.srflow
-import models.archs.srflow.Permutations
-from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation
-from utils.util import opt_get
+from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation, FlowActNorms, Permutations
 
 
 def getConditional(rrdbResults, position):
@@ -28,7 +25,7 @@ class FlowStep(nn.Module):
 
     def __init__(self, in_channels, hidden_channels,
                  actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
-                 LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
+                 LU_decomposed=False, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
                  position=None):
         # check configures
         assert flow_permutation in FlowStep.FlowPermutation, \
@@ -47,17 +44,16 @@ class FlowStep(nn.Module):
         self.acOpt = acOpt
 
         # 1. actnorm
-        self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
+        self.actnorm = FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
 
         # 2. permute
         if flow_permutation == "invconv":
-            self.invconv = models.modules.Permutations.InvertibleConv1x1(
+            self.invconv = Permutations.InvertibleConv1x1(
                 in_channels, LU_decomposed=LU_decomposed)
 
         # 3. coupling
         if flow_coupling == "CondAffineSeparatedAndCond":
-            self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
-                                                                                                opt=opt)
+            self.affine = FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels)
         elif flow_coupling == "noCoupling":
             pass
         else:
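FlowStep likewise loses its opt kwarg; everything its sub-modules need now arrives explicitly. A minimal sketch of building one step under the new signature (values are illustrative):

    from models.archs.srflow.FlowStep import FlowStep

    step = FlowStep(in_channels=12, hidden_channels=64,
                    actnorm_scale=1.0,
                    flow_permutation='invconv',
                    flow_coupling='noCoupling',
                    LU_decomposed=False)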
--- a/models/archs/srflow/FlowUpsamplerNet.py
+++ b/models/archs/srflow/FlowUpsamplerNet.py
@@ -3,34 +3,39 @@ import torch
 from torch import nn as nn
 
 import models.archs.srflow.Split
-from models.archs.srflow import flow, thops
+from models.archs.srflow import flow, thops, Split
 from models.archs.srflow.Split import Split2d
 from models.archs.srflow.glow_arch import f_conv2d_bias
 from models.archs.srflow.FlowStep import FlowStep
 from utils.util import opt_get
+import torchvision
 
 
 class FlowUpsamplerNet(nn.Module):
-    def __init__(self, image_shape, hidden_channels, K, L=None,
+    def __init__(self, image_shape, hidden_channels, scale,
+                 rrdb_blocks,
                  actnorm_scale=1.0,
-                 flow_permutation=None,
+                 flow_permutation='invconv',
                  flow_coupling="affine",
-                 LU_decomposed=False, opt=None):
+                 LU_decomposed=False, K=16, L=3,
+                 norm_opt=None,
+                 n_bypass_channels=None):
 
         super().__init__()
 
         self.layers = nn.ModuleList()
         self.output_shapes = []
-        self.L = opt_get(opt, ['network_G', 'flow', 'L'])
-        self.K = opt_get(opt, ['network_G', 'flow', 'K'])
+        self.L = L
+        self.K = K
+        self.scale=scale
         if isinstance(self.K, int):
             self.K = [K for K in [K, ] * (self.L + 1)]
 
-        self.opt = opt
         H, W, self.C = image_shape
+        self.image_shape = image_shape
         self.check_image_shape()
 
-        if opt['scale'] == 16:
+        if scale == 16:
             self.levelToName = {
                 0: 'fea_up16',
                 1: 'fea_up8',
@@ -39,7 +44,7 @@ class FlowUpsamplerNet(nn.Module):
                 4: 'fea_up1',
             }
 
-        if opt['scale'] == 8:
+        if scale == 8:
             self.levelToName = {
                 0: 'fea_up8',
                 1: 'fea_up4',
@@ -48,7 +53,7 @@ class FlowUpsamplerNet(nn.Module):
                 4: 'fea_up0'
             }
 
-        elif opt['scale'] == 4:
+        elif scale == 4:
             self.levelToName = {
                 0: 'fea_up4',
                 1: 'fea_up2',
@@ -57,14 +62,10 @@ class FlowUpsamplerNet(nn.Module):
                 4: 'fea_up-1'
             }
 
-        affineInCh = self.get_affineInCh(opt_get)
-        flow_permutation = self.get_flow_permutation(flow_permutation, opt)
-
-        normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
+        affineInCh = self.get_affineInCh(rrdb_blocks)
 
         conditional_channels = {}
-        n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
-        n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
+        n_rrdb = self.get_n_rrdb_channels(rrdb_blocks)
        conditional_channels[0] = n_rrdb
         for level in range(1, self.L + 1):
             # Level 1 gets conditionals from 2, 3, 4 => L - level
@@ -80,37 +81,29 @@ class FlowUpsamplerNet(nn.Module):
             H, W = self.arch_squeeze(H, W)
 
             # 2. K FlowStep
-            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
+            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels)
             self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                                flow_permutation,
-                               hidden_channels, normOpt, opt, opt_get,
-                               n_conditinal_channels=conditional_channels[level])
+                               hidden_channels, norm_opt,
+                               n_conditional_channels=conditional_channels[level])
             # Split
-            self.arch_split(H, W, level, self.L, opt, opt_get)
-
-        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
-            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
-        else:
-            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
+            self.arch_split(H, W, level, self.L)
 
+        self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
         self.H = H
         self.W = W
-        self.scaleH = 160 / H
-        self.scaleW = 160 / W
 
-    def get_n_rrdb_channels(self, opt, opt_get):
-        blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
+    def get_n_rrdb_channels(self, blocks):
         n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
         return n_rrdb
 
     def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
-                      hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
-        condAff = self.get_condAffSetting(opt, opt_get)
+                      hidden_channels, normOpt, n_conditional_channels=None, condAff=None):
         if condAff is not None:
-            condAff['in_channels_rrdb'] = n_conditinal_channels
+            condAff['in_channels_rrdb'] = n_conditional_channels
 
         for k in range(K):
-            position_name = get_position_name(H, self.opt['scale'])
+            position_name = self.get_position_name(H, self.scale)
             if normOpt: normOpt['position'] = position_name
 
             self.layers.append(
@@ -121,48 +114,37 @@ class FlowUpsamplerNet(nn.Module):
                          flow_coupling=flow_coupling,
                          acOpt=condAff,
                          position=position_name,
-                         LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
+                         LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt))
             self.output_shapes.append(
                 [-1, self.C, H, W])
 
-    def get_condAffSetting(self, opt, opt_get):
-        condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
-        condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
-        return condAff
-
-    def arch_split(self, H, W, L, levels, opt, opt_get):
-        correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
+    def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5, split_conditional=False, cond_channels=None, split_type='Split2d'):
         correction = 0 if correct_splits else 1
-        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
-            logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
-            consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
-            position_name = get_position_name(H, self.opt['scale'])
-            position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
-            cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
+        if split_flow and L < levels - correction:
+            logs_eps = logs_eps
+            consume_ratio = consume_ratio
+            position_name = self.get_position_name(H, self.scale)
+            position = position_name if split_conditional else None
             cond_channels = 0 if cond_channels is None else cond_channels
 
-            t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
-
-            if t == 'Split2d':
-                split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
-                                                     cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
+            if split_type == 'Split2d':
+                split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
+                                      cond_channels=cond_channels, consume_ratio=consume_ratio)
                 self.layers.append(split)
                 self.output_shapes.append([-1, split.num_channels_pass, H, W])
                 self.C = split.num_channels_pass
 
-    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
-        if 'additionalFlowNoAffine' in opt['network_G']['flow']:
-            n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
-            for _ in range(n_additionalFlowNoAffine):
-                self.layers.append(
-                    FlowStep(in_channels=self.C,
-                             hidden_channels=hidden_channels,
-                             actnorm_scale=actnorm_scale,
-                             flow_permutation='invconv',
-                             flow_coupling='noCoupling',
-                             LU_decomposed=LU_decomposed, opt=opt))
-                self.output_shapes.append(
-                    [-1, self.C, H, W])
+    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2):
+        for _ in range(additionalFlowNoAffine):
+            self.layers.append(
+                FlowStep(in_channels=self.C,
+                         hidden_channels=hidden_channels,
+                         actnorm_scale=actnorm_scale,
+                         flow_permutation='invconv',
+                         flow_coupling='noCoupling',
+                         LU_decomposed=LU_decomposed))
+            self.output_shapes.append(
+                [-1, self.C, H, W])
 
     def arch_squeeze(self, H, W):
         self.C, H, W = self.C * 4, H // 2, W // 2
@@ -170,13 +152,8 @@ class FlowUpsamplerNet(nn.Module):
         self.output_shapes.append([-1, self.C, H, W])
         return H, W
 
-    def get_flow_permutation(self, flow_permutation, opt):
-        flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
-        return flow_permutation
-
-    def get_affineInCh(self, opt_get):
-        affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
-        affineInCh = (len(affineInCh) + 1) * 64
+    def get_affineInCh(self, rrdb_blocks):
+        affineInCh = (len(rrdb_blocks) + 1) * 64
         return affineInCh
 
     def check_image_shape(self):
@@ -204,14 +181,14 @@ class FlowUpsamplerNet(nn.Module):
         level_conditionals = {}
         bypasses = {}
 
-        L = opt_get(self.opt, ['network_G', 'flow', 'L'])
+        L = self.L
 
         for level in range(1, L + 1):
             bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
 
         for layer, shape in zip(self.layers, self.output_shapes):
             size = shape[2]
-            level = int(np.log(160 / size) / np.log(2))
+            level = int(np.log(self.image_shape[0] / size) / np.log(2))
 
             if level > 0 and level not in level_conditionals.keys():
                 level_conditionals[level] = rrdbResults[self.levelToName[level]]
@@ -255,15 +232,12 @@ class FlowUpsamplerNet(nn.Module):
         # debug.imwrite("fl_fea", fl_fea)
         bypasses = {}
         level_conditionals = {}
-        if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
-            for level in range(self.L + 1):
-                level_conditionals[level] = rrdbResults[self.levelToName[level]]
+        for level in range(self.L + 1):
+            level_conditionals[level] = rrdbResults[self.levelToName[level]]
 
         for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
             size = shape[2]
-            level = int(np.log(160 / size) / np.log(2))
-            # size = fl_fea.shape[2]
-            # level = int(np.log(160 / size) / np.log(2))
+            level = int(np.log(self.H / size) / np.log(2))
 
             if isinstance(layer, Split2d):
                 fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
@@ -287,7 +261,7 @@ class FlowUpsamplerNet(nn.Module):
         return fl_fea, logdet
 
 
-def get_position_name(H, scale):
-    downscale_factor = 160 // H
-    position_name = 'fea_up{}'.format(scale / downscale_factor)
-    return position_name
+    def get_position_name(self, H, scale):
+        downscale_factor = self.image_shape[0] // H
+        position_name = 'fea_up{}'.format(scale / downscale_factor)
+        return position_name
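FlowUpsamplerNet now takes the scale, the stacked-RRDB block list, and the flow depth (K steps per level across L levels) as direct arguments, and derives its 'fea_upN' level names from scale instead of opt['scale']. A hypothetical construction mirroring the defaults SRFlowNet passes further down in this commit (the block indices are an assumption):

    from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet

    net = FlowUpsamplerNet(image_shape=(128, 128, 3),
                           hidden_channels=64,
                           scale=4,
                           rrdb_blocks=[1, 8, 15, 22],  # assumed stackRRDB block indices
                           K=16, L=3,
                           flow_coupling='CondAffineSeparatedAndCond')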
--- a/models/archs/srflow/Permutations.py
+++ b/models/archs/srflow/Permutations.py
@@ -3,7 +3,7 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
-from models.modules import thops
+from models.archs.srflow import thops
 
 
 class InvertibleConv1x1(nn.Module):
--- a/models/archs/srflow/RRDBNet_arch.py
+++ b/models/archs/srflow/RRDBNet_arch.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import models.archs.srflow.module_util as mutil
-from utils.util import opt_get
+from utils.util import opt_get, checkpoint
 
 
 class ResidualDenseBlock_5C(nn.Module):
@@ -46,11 +46,14 @@ class RRDB(nn.Module):
 
 
 class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
-        self.opt = opt
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, block_outputs=[], fea_up0=True,
+                 fea_up1=False):
         super(RRDBNet, self).__init__()
         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
         self.scale = scale
+        self.block_outputs = block_outputs
+        self.fea_up0 = fea_up0
+        self.fea_up1 = fea_up1
 
         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
         self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
@@ -73,11 +76,11 @@ class RRDBNet(nn.Module):
     def forward(self, x, get_steps=False):
         fea = self.conv_first(x)
 
-        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+        block_idxs = self.block_outputs or []
         block_results = {}
 
         for idx, m in enumerate(self.RRDB_trunk.children()):
-            fea = m(fea)
+            fea = checkpoint(m, fea)
             for b in block_idxs:
                 if b == idx:
                     block_results["block_{}".format(idx)] = fea
@@ -117,11 +120,9 @@ class RRDBNet(nn.Module):
                    'fea_up32': fea_up32,
                    'out': out}
 
-        fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
-        if fea_up0_en:
+        if self.fea_up0:
             results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
-        fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
-        if fea_upn1_en:
+        if self.fea_up1:
             results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
 
         if get_steps:
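The RRDB trunk also starts running each block through the checkpoint wrapper imported from utils.util, which trades compute for memory by recomputing a block's activations during the backward pass instead of storing them. Assuming it is a thin wrapper over PyTorch's built-in utility (the repository's actual implementation may differ), it behaves roughly like:

    import torch.utils.checkpoint

    def checkpoint(fn, *args):
        # Don't retain fn's intermediate activations; recompute them on backward.
        return torch.utils.checkpoint.checkpoint(fn, *args)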
--- a/models/archs/srflow/Split.py
+++ b/models/archs/srflow/Split.py
@@ -7,7 +7,7 @@ from models.archs.srflow.flow import Conv2dZeros, GaussianDiag
 
 
 class Split2d(nn.Module):
-    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
+    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5):
         super().__init__()
 
         self.num_channels_consume = int(round(num_channels * consume_ratio))
@@ -17,7 +17,6 @@ class Split2d(nn.Module):
                               out_channels=self.num_channels_consume * 2)
         self.logs_eps = logs_eps
         self.position = position
-        self.opt = opt
 
     def split2d_prior(self, z, ft):
         if ft is not None:
--- a/models/archs/srflow/SRFlowNet_arch.py
+++ b/models/archs/srflow/SRFlowNet_arch.py
@@ -4,57 +4,41 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 
 from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet
 import models.archs.srflow.thops as thops
 import models.archs.srflow.flow as flow
-from utils.util import opt_get
+from models.archs.srflow.RRDBNet_arch import RRDBNet
 
 
 class SRFlowNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
+    def __init__(self, in_nc, out_nc, nf, nb, quant, flow_block_maps, noise_quant,
+                 hidden_channels=64, gc=32, scale=4, K=16, L=3, train_rrdb_at_step=0,
+                 hr_img_shape=(128,128,3), coupling='CondAffineSeparatedAndCond'):
         super(SRFlowNet, self).__init__()
 
-        self.opt = opt
-        self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
-            None else opt_get(opt, ['datasets', 'train', 'quant'])
-        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
-        hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
-        hidden_channels = hidden_channels or 64
-        self.RRDB_training = True  # Default is true
+        self.scale = scale
+        self.noise_quant = noise_quant
+        self.quant = quant
+        self.flow_block_maps = flow_block_maps
+        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, flow_block_maps)
+        self.train_rrdb_step = train_rrdb_at_step
+        self.RRDB_training = True
 
-        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
-        set_RRDB_to_train = False
-        if set_RRDB_to_train:
-            self.set_rrdb_training(True)
-
-        self.flowUpsamplerNet = \
-            FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
-                             flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
+        self.flowUpsamplerNet = FlowUpsamplerNet(image_shape=hr_img_shape,
+                                                 hidden_channels=hidden_channels,
+                                                 scale=scale, rrdb_blocks=flow_block_maps,
+                                                 K=K, L=L, flow_coupling=coupling)
         self.i = 0
 
-    def set_rrdb_training(self, trainable):
-        if self.RRDB_training != trainable:
-            for p in self.RRDB.parameters():
-                p.requires_grad = trainable
-            self.RRDB_training = trainable
-            return True
-        return False
-
-    def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
+    def forward(self, gt=None, lr=None, reverse=False, z=None, eps_std=None, epses=None, reverse_with_grad=False,
                 lr_enc=None,
                 add_gt_noise=False, step=None, y_label=None):
         if not reverse:
             return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                     y_onehot=y_label)
         else:
-            # assert lr.shape[0] == 1
             assert lr.shape[1] == 3
-            # assert lr.shape[2] == 20
-            # assert lr.shape[3] == 20
-            # assert z.shape[0] == 1
-            # assert z.shape[1] == 3 * 8 * 8
-            # assert z.shape[2] == 20
-            # assert z.shape[3] == 20
             if reverse_with_grad:
                 return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                          add_gt_noise=add_gt_noise)
@@ -74,8 +58,7 @@ class SRFlowNet(nn.Module):
 
         if add_gt_noise:
             # Setup
-            noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
-            if noiseQuant:
+            if self.noise_quant:
                 z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
             logdet = logdet + float(-np.log(self.quant) * pixels)
 
@@ -100,24 +83,23 @@ class SRFlowNet(nn.Module):
 
     def rrdbPreprocessing(self, lr):
         rrdbResults = self.RRDB(lr, get_steps=True)
-        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+        block_idxs = self.flow_block_maps
         if len(block_idxs) > 0:
             concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
 
-            if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
-                keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
-                if 'fea_up0' in rrdbResults.keys():
-                    keys.append('fea_up0')
-                if 'fea_up-1' in rrdbResults.keys():
-                    keys.append('fea_up-1')
-                if self.opt['scale'] >= 8:
-                    keys.append('fea_up8')
-                if self.opt['scale'] == 16:
-                    keys.append('fea_up16')
-                for k in keys:
-                    h = rrdbResults[k].shape[2]
-                    w = rrdbResults[k].shape[3]
-                    rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
+            keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
+            if 'fea_up0' in rrdbResults.keys():
+                keys.append('fea_up0')
+            if 'fea_up-1' in rrdbResults.keys():
+                keys.append('fea_up-1')
+            if self.scale >= 8:
+                keys.append('fea_up8')
+            if self.scale == 16:
+                keys.append('fea_up16')
+            for k in keys:
+                h = rrdbResults[k].shape[2]
+                w = rrdbResults[k].shape[3]
+                rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
         return rrdbResults
 
     def get_score(self, disc_loss_sigma, z):
@@ -127,7 +109,7 @@ class SRFlowNet(nn.Module):
 
     def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
         logdet = torch.zeros_like(lr[:, 0, 0, 0])
-        pixels = thops.pixels(lr) * self.opt['scale'] ** 2
+        pixels = thops.pixels(lr) * self.scale ** 2
 
         if add_gt_noise:
             logdet = logdet - float(-np.log(self.quant) * pixels)
@@ -138,4 +120,16 @@ class SRFlowNet(nn.Module):
         x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
                                           logdet=logdet)
 
         return x, logdet
+
+    def set_rrdb_training(self, trainable):
+        if self.RRDB_training != trainable:
+            for p in self.RRDB.parameters():
+                if not trainable:
+                    p.DO_NOT_TRAIN = True
+                elif hasattr(p, "DO_NOT_TRAIN"):
+                    del p.DO_NOT_TRAIN
+            self.RRDB_training = trainable
+
+    def update_for_step(self, step, experiments_path='.'):
+        self.set_rrdb_training(step > self.train_rrdb_step)