Add srflow arch

This commit is contained in:
James Betker 2020-11-06 20:38:04 -07:00
parent 4469d2e661
commit 34d319585c
13 changed files with 1335 additions and 0 deletions

View File

@ -0,0 +1,125 @@
import torch
from torch import nn as nn
from models.archs.srflow import thops
class _ActNorm(nn.Module):
"""
Activation Normalization
Initialize the bias and scale with a given minibatch,
so that the output per-channel have zero mean and unit variance for that.
After initialization, `bias` and `logs` will be trained as parameters.
"""
def __init__(self, num_features, scale=1.):
super().__init__()
# register mean and scale
size = [1, num_features, 1, 1]
self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
self.num_features = num_features
self.scale = float(scale)
self.inited = False
def _check_input_dim(self, input):
return NotImplemented
def initialize_parameters(self, input):
self._check_input_dim(input)
if not self.training:
return
if (self.bias != 0).any():
self.inited = True
return
assert input.device == self.bias.device, (input.device, self.bias.device)
with torch.no_grad():
bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
self.bias.data.copy_(bias.data)
self.logs.data.copy_(logs.data)
self.inited = True
def _center(self, input, reverse=False, offset=None):
bias = self.bias
if offset is not None:
bias = bias + offset
if not reverse:
return input + bias
else:
return input - bias
def _scale(self, input, logdet=None, reverse=False, offset=None):
logs = self.logs
if offset is not None:
logs = logs + offset
if not reverse:
input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1
# input = input * torch.exp(logs+logs_offset)
else:
input = input * torch.exp(-logs)
if logdet is not None:
"""
logs is log_std of `mean of channels`
so we need to multiply pixels
"""
dlogdet = thops.sum(logs) * thops.pixels(input)
if reverse:
dlogdet *= -1
logdet = logdet + dlogdet
return input, logdet
def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
if not self.inited:
self.initialize_parameters(input)
self._check_input_dim(input)
if offset_mask is not None:
logs_offset *= offset_mask
bias_offset *= offset_mask
# no need to permute dims as old version
if not reverse:
# center and scale
# self.input = input
input = self._center(input, reverse, bias_offset)
input, logdet = self._scale(input, logdet, reverse, logs_offset)
else:
# scale and center
input, logdet = self._scale(input, logdet, reverse, logs_offset)
input = self._center(input, reverse, bias_offset)
return input, logdet
class ActNorm2d(_ActNorm):
def __init__(self, num_features, scale=1.):
super().__init__(num_features, scale)
def _check_input_dim(self, input):
assert len(input.size()) == 4
assert input.size(1) == self.num_features, (
"[ActNorm]: input should be in shape as `BCHW`,"
" channels should be {} rather than {}".format(
self.num_features, input.size()))
class MaskedActNorm2d(ActNorm2d):
def __init__(self, num_features, scale=1.):
super().__init__(num_features, scale)
def forward(self, input, mask, logdet=None, reverse=False):
assert mask.dtype == torch.bool
output, logdet_out = super().forward(input, logdet, reverse)
input[mask] = output[mask]
logdet[mask] = logdet_out[mask]
return input, logdet

View File

@ -0,0 +1,119 @@
import torch
from torch import nn as nn
from models.archs.srflow import thops
from models.archs.srflow.flow import Conv2d, Conv2dZeros
from utils.util import opt_get
class CondAffineSeparatedAndCond(nn.Module):
def __init__(self, in_channels, opt):
super().__init__()
self.need_features = True
self.in_channels = in_channels
self.in_channels_rrdb = 320
self.kernel_hidden = 1
self.affine_eps = 0.0001
self.n_hidden_layers = 1
hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
self.hidden_channels = 64 if hidden_channels is None else hidden_channels
self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
self.channels_for_nn = self.in_channels // 2
self.channels_for_co = self.in_channels - self.channels_for_nn
if self.channels_for_nn is None:
self.channels_for_nn = self.in_channels // 2
self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
out_channels=self.channels_for_co * 2,
hidden_channels=self.hidden_channels,
kernel_hidden=self.kernel_hidden,
n_hidden_layers=self.n_hidden_layers)
self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
out_channels=self.in_channels * 2,
hidden_channels=self.hidden_channels,
kernel_hidden=self.kernel_hidden,
n_hidden_layers=self.n_hidden_layers)
def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
if not reverse:
z = input
assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)
# Feature Conditional
scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
z = z + shiftFt
z = z * scaleFt
logdet = logdet + self.get_logdet(scaleFt)
# Self Conditional
z1, z2 = self.split(z)
scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
self.asserts(scale, shift, z1, z2)
z2 = z2 + shift
z2 = z2 * scale
logdet = logdet + self.get_logdet(scale)
z = thops.cat_feature(z1, z2)
output = z
else:
z = input
# Self Conditional
z1, z2 = self.split(z)
scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
self.asserts(scale, shift, z1, z2)
z2 = z2 / scale
z2 = z2 - shift
z = thops.cat_feature(z1, z2)
logdet = logdet - self.get_logdet(scale)
# Feature Conditional
scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
z = z / scaleFt
z = z - shiftFt
logdet = logdet - self.get_logdet(scaleFt)
output = z
return output, logdet
def asserts(self, scale, shift, z1, z2):
assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])
def get_logdet(self, scale):
return thops.sum(torch.log(scale), dim=[1, 2, 3])
def feature_extract(self, z, f):
h = f(z)
shift, scale = thops.split_feature(h, "cross")
scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
return scale, shift
def feature_extract_aff(self, z1, ft, f):
z = torch.cat([z1, ft], dim=1)
h = f(z)
shift, scale = thops.split_feature(h, "cross")
scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
return scale, shift
def split(self, z):
z1 = z[:, :self.channels_for_nn]
z2 = z[:, self.channels_for_nn:]
assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
return z1, z2
def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]
for _ in range(n_hidden_layers):
layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
layers.append(nn.ReLU(inplace=False))
layers.append(Conv2dZeros(hidden_channels, out_channels))
return nn.Sequential(*layers)

View File

@ -0,0 +1,121 @@
import torch
from torch import nn as nn
import models.archs.srflow
import models.archs.srflow.Permutations
from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation
from utils.util import opt_get
def getConditional(rrdbResults, position):
img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position]
return img_ft
class FlowStep(nn.Module):
FlowPermutation = {
"reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
"shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
"invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
"InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
}
def __init__(self, in_channels, hidden_channels,
actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
position=None):
# check configures
assert flow_permutation in FlowStep.FlowPermutation, \
"float_permutation should be in `{}`".format(
FlowStep.FlowPermutation.keys())
super().__init__()
self.flow_permutation = flow_permutation
self.flow_coupling = flow_coupling
self.image_injector = image_injector
self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'
self.position = normOpt['position'] if normOpt else None
self.in_shape = in_shape
self.position = position
self.acOpt = acOpt
# 1. actnorm
self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
# 2. permute
if flow_permutation == "invconv":
self.invconv = models.modules.Permutations.InvertibleConv1x1(
in_channels, LU_decomposed=LU_decomposed)
# 3. coupling
if flow_coupling == "CondAffineSeparatedAndCond":
self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
opt=opt)
elif flow_coupling == "noCoupling":
pass
else:
raise RuntimeError("coupling not Found:", flow_coupling)
def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
if not reverse:
return self.normal_flow(input, logdet, rrdbResults)
else:
return self.reverse_flow(input, logdet, rrdbResults)
def normal_flow(self, z, logdet, rrdbResults=None):
if self.flow_coupling == "bentIdentityPreAct":
z, logdet = self.bentIdentPar(z, logdet, reverse=False)
# 1. actnorm
if self.norm_type == "ConditionalActNormImageInjector":
img_ft = getConditional(rrdbResults, self.position)
z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
elif self.norm_type == "noNorm":
pass
else:
z, logdet = self.actnorm(z, logdet=logdet, reverse=False)
# 2. permute
z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
self, z, logdet, False)
need_features = self.affine_need_features()
# 3. coupling
if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
img_ft = getConditional(rrdbResults, self.position)
z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
return z, logdet
def reverse_flow(self, z, logdet, rrdbResults=None):
need_features = self.affine_need_features()
# 1.coupling
if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
img_ft = getConditional(rrdbResults, self.position)
z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)
# 2. permute
z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
self, z, logdet, True)
# 3. actnorm
z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
return z, logdet
def affine_need_features(self):
need_features = False
try:
need_features = self.affine.need_features
except:
pass
return need_features

View File

@ -0,0 +1,293 @@
import numpy as np
import torch
from torch import nn as nn
import models.archs.srflow.Split
from models.archs.srflow import flow, thops
from models.archs.srflow.Split import Split2d
from models.archs.srflow.glow_arch import f_conv2d_bias
from models.archs.srflow.FlowStep import FlowStep
from utils.util import opt_get
class FlowUpsamplerNet(nn.Module):
def __init__(self, image_shape, hidden_channels, K, L=None,
actnorm_scale=1.0,
flow_permutation=None,
flow_coupling="affine",
LU_decomposed=False, opt=None):
super().__init__()
self.layers = nn.ModuleList()
self.output_shapes = []
self.L = opt_get(opt, ['network_G', 'flow', 'L'])
self.K = opt_get(opt, ['network_G', 'flow', 'K'])
if isinstance(self.K, int):
self.K = [K for K in [K, ] * (self.L + 1)]
self.opt = opt
H, W, self.C = image_shape
self.check_image_shape()
if opt['scale'] == 16:
self.levelToName = {
0: 'fea_up16',
1: 'fea_up8',
2: 'fea_up4',
3: 'fea_up2',
4: 'fea_up1',
}
if opt['scale'] == 8:
self.levelToName = {
0: 'fea_up8',
1: 'fea_up4',
2: 'fea_up2',
3: 'fea_up1',
4: 'fea_up0'
}
elif opt['scale'] == 4:
self.levelToName = {
0: 'fea_up4',
1: 'fea_up2',
2: 'fea_up1',
3: 'fea_up0',
4: 'fea_up-1'
}
affineInCh = self.get_affineInCh(opt_get)
flow_permutation = self.get_flow_permutation(flow_permutation, opt)
normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
conditional_channels = {}
n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
conditional_channels[0] = n_rrdb
for level in range(1, self.L + 1):
# Level 1 gets conditionals from 2, 3, 4 => L - level
# Level 2 gets conditionals from 3, 4
# Level 3 gets conditionals from 4
# Level 4 gets conditionals from None
n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
conditional_channels[level] = n_rrdb + n_bypass
# Upsampler
for level in range(1, self.L + 1):
# 1. Squeeze
H, W = self.arch_squeeze(H, W)
# 2. K FlowStep
self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
flow_permutation,
hidden_channels, normOpt, opt, opt_get,
n_conditinal_channels=conditional_channels[level])
# Split
self.arch_split(H, W, level, self.L, opt, opt_get)
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
else:
self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
self.H = H
self.W = W
self.scaleH = 160 / H
self.scaleW = 160 / W
def get_n_rrdb_channels(self, opt, opt_get):
blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
return n_rrdb
def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
condAff = self.get_condAffSetting(opt, opt_get)
if condAff is not None:
condAff['in_channels_rrdb'] = n_conditinal_channels
for k in range(K):
position_name = get_position_name(H, self.opt['scale'])
if normOpt: normOpt['position'] = position_name
self.layers.append(
FlowStep(in_channels=self.C,
hidden_channels=hidden_channels,
actnorm_scale=actnorm_scale,
flow_permutation=flow_permutation,
flow_coupling=flow_coupling,
acOpt=condAff,
position=position_name,
LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
self.output_shapes.append(
[-1, self.C, H, W])
def get_condAffSetting(self, opt, opt_get):
condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
return condAff
def arch_split(self, H, W, L, levels, opt, opt_get):
correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
correction = 0 if correct_splits else 1
if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
position_name = get_position_name(H, self.opt['scale'])
position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
cond_channels = 0 if cond_channels is None else cond_channels
t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
if t == 'Split2d':
split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
self.layers.append(split)
self.output_shapes.append([-1, split.num_channels_pass, H, W])
self.C = split.num_channels_pass
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
if 'additionalFlowNoAffine' in opt['network_G']['flow']:
n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
for _ in range(n_additionalFlowNoAffine):
self.layers.append(
FlowStep(in_channels=self.C,
hidden_channels=hidden_channels,
actnorm_scale=actnorm_scale,
flow_permutation='invconv',
flow_coupling='noCoupling',
LU_decomposed=LU_decomposed, opt=opt))
self.output_shapes.append(
[-1, self.C, H, W])
def arch_squeeze(self, H, W):
self.C, H, W = self.C * 4, H // 2, W // 2
self.layers.append(flow.SqueezeLayer(factor=2))
self.output_shapes.append([-1, self.C, H, W])
return H, W
def get_flow_permutation(self, flow_permutation, opt):
flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
return flow_permutation
def get_affineInCh(self, opt_get):
affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
affineInCh = (len(affineInCh) + 1) * 64
return affineInCh
def check_image_shape(self):
assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3)"
"self.C == 1 or self.C == 3")
def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
y_onehot=None):
if reverse:
epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses
sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
return sr, logdet
else:
assert gt is not None
assert rrdbResults is not None
z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
return z, logdet
def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
fl_fea = gt
reverse = False
level_conditionals = {}
bypasses = {}
L = opt_get(self.opt, ['network_G', 'flow', 'L'])
for level in range(1, L + 1):
bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
for layer, shape in zip(self.layers, self.output_shapes):
size = shape[2]
level = int(np.log(160 / size) / np.log(2))
if level > 0 and level not in level_conditionals.keys():
level_conditionals[level] = rrdbResults[self.levelToName[level]]
level_conditionals[level] = rrdbResults[self.levelToName[level]]
if isinstance(layer, FlowStep):
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
elif isinstance(layer, Split2d):
fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
y_onehot=y_onehot)
else:
fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)
z = fl_fea
if not isinstance(epses, list):
return z, logdet
epses.append(z)
return epses, logdet
def forward_preFlow(self, fl_fea, logdet, reverse):
if hasattr(self, 'preFlow'):
for l in self.preFlow:
fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
return fl_fea, logdet
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
ft = None if layer.position is None else rrdbResults[layer.position]
fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
if isinstance(epses, list):
epses.append(eps)
return fl_fea, logdet
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
z = epses.pop() if isinstance(epses, list) else z
fl_fea = z
# debug.imwrite("fl_fea", fl_fea)
bypasses = {}
level_conditionals = {}
if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
for level in range(self.L + 1):
level_conditionals[level] = rrdbResults[self.levelToName[level]]
for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
size = shape[2]
level = int(np.log(160 / size) / np.log(2))
# size = fl_fea.shape[2]
# level = int(np.log(160 / size) / np.log(2))
if isinstance(layer, Split2d):
fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
rrdbResults[self.levelToName[level]], logdet=logdet,
y_onehot=y_onehot)
elif isinstance(layer, FlowStep):
fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
else:
fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
sr = fl_fea
assert sr.shape[1] == 3
return sr, logdet
def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
ft = None if layer.position is None else rrdbResults[layer.position]
fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
eps=epses.pop() if isinstance(epses, list) else None,
eps_std=eps_std, ft=ft, y_onehot=y_onehot)
return fl_fea, logdet
def get_position_name(H, scale):
downscale_factor = 160 // H
position_name = 'fea_up{}'.format(scale / downscale_factor)
return position_name

View File

@ -0,0 +1,42 @@
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from models.modules import thops
class InvertibleConv1x1(nn.Module):
def __init__(self, num_channels, LU_decomposed=False):
super().__init__()
w_shape = [num_channels, num_channels]
w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
self.w_shape = w_shape
self.LU = LU_decomposed
def get_weight(self, input, reverse):
w_shape = self.w_shape
pixels = thops.pixels(input)
dlogdet = torch.slogdet(self.weight)[1] * pixels
if not reverse:
weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
else:
weight = torch.inverse(self.weight.double()).float() \
.view(w_shape[0], w_shape[1], 1, 1)
return weight, dlogdet
def forward(self, input, logdet=None, reverse=False):
"""
log-det = log|abs(|W|)| * pixels
"""
weight, dlogdet = self.get_weight(input, reverse)
if not reverse:
z = F.conv2d(input, weight)
if logdet is not None:
logdet = logdet + dlogdet
return z, logdet
else:
z = F.conv2d(input, weight)
if logdet is not None:
logdet = logdet - dlogdet
return z, logdet

View File

@ -0,0 +1,132 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.srflow.module_util as mutil
from utils.util import opt_get
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
self.opt = opt
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 8:
self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 16:
self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.scale >= 32:
self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x, get_steps=False):
fea = self.conv_first(x)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
block_results = {}
for idx, m in enumerate(self.RRDB_trunk.children()):
fea = m(fea)
for b in block_idxs:
if b == idx:
block_results["block_{}".format(idx)] = fea
trunk = self.trunk_conv(fea)
last_lr_fea = fea + trunk
fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up2)
fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up4)
fea_up8 = None
fea_up16 = None
fea_up32 = None
if self.scale >= 8:
fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up8)
if self.scale >= 16:
fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up16)
if self.scale >= 32:
fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(fea_up32)
out = self.conv_last(self.lrelu(self.HRconv(fea)))
results = {'last_lr_fea': last_lr_fea,
'fea_up1': last_lr_fea,
'fea_up2': fea_up2,
'fea_up4': fea_up4,
'fea_up8': fea_up8,
'fea_up16': fea_up16,
'fea_up32': fea_up32,
'out': out}
fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
if fea_up0_en:
results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
if fea_upn1_en:
results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
if get_steps:
for k, v in block_results.items():
results[k] = v
return results
else:
return out

View File

@ -0,0 +1,69 @@
import torch
from torch import nn as nn
from models.archs.srflow import thops
from models.archs.srflow.FlowStep import FlowStep
from models.archs.srflow.flow import Conv2dZeros, GaussianDiag
class Split2d(nn.Module):
def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
super().__init__()
self.num_channels_consume = int(round(num_channels * consume_ratio))
self.num_channels_pass = num_channels - self.num_channels_consume
self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
out_channels=self.num_channels_consume * 2)
self.logs_eps = logs_eps
self.position = position
self.opt = opt
def split2d_prior(self, z, ft):
if ft is not None:
z = torch.cat([z, ft], dim=1)
h = self.conv(z)
return thops.split_feature(h, "cross")
def exp_eps(self, logs):
return torch.exp(logs) + self.logs_eps
def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
if not reverse:
# self.input = input
z1, z2 = self.split_ratio(input)
mean, logs = self.split2d_prior(z1, ft)
eps = (z2 - mean) / self.exp_eps(logs)
logdet = logdet + self.get_logdet(logs, mean, z2)
# print(logs.shape, mean.shape, z2.shape)
# self.eps = eps
# print('split, enc eps:', eps)
return z1, logdet, eps
else:
z1 = input
mean, logs = self.split2d_prior(z1, ft)
if eps is None:
#print("WARNING: eps is None, generating eps untested functionality!")
eps = GaussianDiag.sample_eps(mean.shape, eps_std)
eps = eps.to(mean.device)
z2 = mean + self.exp_eps(logs) * eps
z = thops.cat_feature(z1, z2)
logdet = logdet - self.get_logdet(logs, mean, z2)
return z, logdet
# return z, logdet, eps
def get_logdet(self, logs, mean, z2):
logdet_diff = GaussianDiag.logp(mean, logs, z2)
# print("Split2D: logdet diff", logdet_diff.item())
return logdet_diff
def split_ratio(self, input):
z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
return z1, z2

View File

View File

@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.archs.srflow.FlowActNorms import ActNorm2d
from . import thops
class Conv2d(nn.Conv2d):
pad_dict = {
"same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
"valid": lambda kernel, stride: [0 for _ in kernel]
}
@staticmethod
def get_padding(padding, kernel_size, stride):
# make paddding
if isinstance(padding, str):
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
if isinstance(stride, int):
stride = [stride, stride]
padding = padding.lower()
try:
padding = Conv2d.pad_dict[padding](kernel_size, stride)
except KeyError:
raise ValueError("{} is not supported".format(padding))
return padding
def __init__(self, in_channels, out_channels,
kernel_size=[3, 3], stride=[1, 1],
padding="same", do_actnorm=True, weight_std=0.05):
padding = Conv2d.get_padding(padding, kernel_size, stride)
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, bias=(not do_actnorm))
# init weight with std
self.weight.data.normal_(mean=0.0, std=weight_std)
if not do_actnorm:
self.bias.data.zero_()
else:
self.actnorm = ActNorm2d(out_channels)
self.do_actnorm = do_actnorm
def forward(self, input):
x = super().forward(input)
if self.do_actnorm:
x, _ = self.actnorm(x)
return x
class Conv2dZeros(nn.Conv2d):
def __init__(self, in_channels, out_channels,
kernel_size=[3, 3], stride=[1, 1],
padding="same", logscale_factor=3):
padding = Conv2d.get_padding(padding, kernel_size, stride)
super().__init__(in_channels, out_channels, kernel_size, stride, padding)
# logscale_factor
self.logscale_factor = logscale_factor
self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
# init
self.weight.data.zero_()
self.bias.data.zero_()
def forward(self, input):
output = super().forward(input)
return output * torch.exp(self.logs * self.logscale_factor)
class GaussianDiag:
Log2PI = float(np.log(2 * np.pi))
@staticmethod
def likelihood(mean, logs, x):
"""
lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
k = 1 (Independent)
Var = logs ** 2
"""
if mean is None and logs is None:
return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
else:
return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)
@staticmethod
def logp(mean, logs, x):
likelihood = GaussianDiag.likelihood(mean, logs, x)
return thops.sum(likelihood, dim=[1, 2, 3])
@staticmethod
def sample(mean, logs, eps_std=None):
eps_std = eps_std or 1
eps = torch.normal(mean=torch.zeros_like(mean),
std=torch.ones_like(logs) * eps_std)
return mean + torch.exp(logs) * eps
@staticmethod
def sample_eps(shape, eps_std, seed=None):
if seed is not None:
torch.manual_seed(seed)
eps = torch.normal(mean=torch.zeros(shape),
std=torch.ones(shape) * eps_std)
return eps
def squeeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
x = input.view(B, C, H // factor, factor, W // factor, factor)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(B, C * factor * factor, H // factor, W // factor)
return x
def unsqueeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
factor2 = factor ** 2
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert C % (factor2) == 0, "{}".format(C)
x = input.view(B, C // factor2, factor, factor, H, W)
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(B, C // (factor2), H * factor, W * factor)
return x
class SqueezeLayer(nn.Module):
def __init__(self, factor):
super().__init__()
self.factor = factor
def forward(self, input, logdet=None, reverse=False):
if not reverse:
output = squeeze2d(input, self.factor) # Squeeze in forward
return output, logdet
else:
output = unsqueeze2d(input, self.factor)
return output, logdet

View File

@ -0,0 +1,12 @@
import torch.nn as nn
def f_conv2d_bias(in_channels, out_channels):
def padding_same(kernel, stride):
return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]
padding = padding_same([3, 3], [1, 1])
assert padding == [1, 1], padding
return nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1,
bias=True))

View File

@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualBlock_noBN(nn.Module):
'''Residual block w/o BN
---Conv-ReLU-Conv-+-
|________________|
'''
def __init__(self, nf=64):
super(ResidualBlock_noBN, self).__init__()
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
# initialization
initialize_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = F.relu(self.conv1(x), inplace=True)
out = self.conv2(out)
return identity + out
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
"""Warp an image or feature map with optical flow
Args:
x (Tensor): size (N, C, H, W)
flow (Tensor): size (N, H, W, 2), normal value
interp_mode (str): 'nearest' or 'bilinear'
padding_mode (str): 'zeros' or 'border' or 'reflection'
Returns:
Tensor: warped image or feature map
"""
assert x.size()[-2:] == flow.size()[1:3]
B, C, H, W = x.size()
# mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
grid.requires_grad = False
grid = grid.type_as(x)
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
return output

View File

@ -0,0 +1,141 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.archs.srflow.thops as thops
import models.archs.srflow.flow as flow
from utils.util import opt_get
class SRFlowNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
super(SRFlowNet, self).__init__()
self.opt = opt
self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
None else opt_get(opt, ['datasets', 'train', 'quant'])
self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
hidden_channels = hidden_channels or 64
self.RRDB_training = True # Default is true
train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
set_RRDB_to_train = False
if set_RRDB_to_train:
self.set_rrdb_training(True)
self.flowUpsamplerNet = \
FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
self.i = 0
def set_rrdb_training(self, trainable):
if self.RRDB_training != trainable:
for p in self.RRDB.parameters():
p.requires_grad = trainable
self.RRDB_training = trainable
return True
return False
def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
lr_enc=None,
add_gt_noise=False, step=None, y_label=None):
if not reverse:
return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
y_onehot=y_label)
else:
# assert lr.shape[0] == 1
assert lr.shape[1] == 3
# assert lr.shape[2] == 20
# assert lr.shape[3] == 20
# assert z.shape[0] == 1
# assert z.shape[1] == 3 * 8 * 8
# assert z.shape[2] == 20
# assert z.shape[3] == 20
if reverse_with_grad:
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
add_gt_noise=add_gt_noise)
else:
with torch.no_grad():
return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
add_gt_noise=add_gt_noise)
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
if lr_enc is None:
lr_enc = self.rrdbPreprocessing(lr)
logdet = torch.zeros_like(gt[:, 0, 0, 0])
pixels = thops.pixels(gt)
z = gt
if add_gt_noise:
# Setup
noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
if noiseQuant:
z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
logdet = logdet + float(-np.log(self.quant) * pixels)
# Encode
epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
y_onehot=y_onehot)
objective = logdet.clone()
if isinstance(epses, (list, tuple)):
z = epses[-1]
else:
z = epses
objective = objective + flow.GaussianDiag.logp(None, None, z)
nll = (-objective) / float(np.log(2.) * pixels)
if isinstance(epses, list):
return epses, nll, logdet
return z, nll, logdet
def rrdbPreprocessing(self, lr):
rrdbResults = self.RRDB(lr, get_steps=True)
block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
if len(block_idxs) > 0:
concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
if 'fea_up0' in rrdbResults.keys():
keys.append('fea_up0')
if 'fea_up-1' in rrdbResults.keys():
keys.append('fea_up-1')
if self.opt['scale'] >= 8:
keys.append('fea_up8')
if self.opt['scale'] == 16:
keys.append('fea_up16')
for k in keys:
h = rrdbResults[k].shape[2]
w = rrdbResults[k].shape[3]
rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
return rrdbResults
def get_score(self, disc_loss_sigma, z):
score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
return -score_real
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
logdet = torch.zeros_like(lr[:, 0, 0, 0])
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
if add_gt_noise:
logdet = logdet - float(-np.log(self.quant) * pixels)
if lr_enc is None:
lr_enc = self.rrdbPreprocessing(lr)
x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
logdet=logdet)
return x, logdet

View File

@ -0,0 +1,52 @@
import torch
def sum(tensor, dim=None, keepdim=False):
if dim is None:
# sum up all dim
return torch.sum(tensor)
else:
if isinstance(dim, int):
dim = [dim]
dim = sorted(dim)
for d in dim:
tensor = tensor.sum(dim=d, keepdim=True)
if not keepdim:
for i, d in enumerate(dim):
tensor.squeeze_(d-i)
return tensor
def mean(tensor, dim=None, keepdim=False):
if dim is None:
# mean all dim
return torch.mean(tensor)
else:
if isinstance(dim, int):
dim = [dim]
dim = sorted(dim)
for d in dim:
tensor = tensor.mean(dim=d, keepdim=True)
if not keepdim:
for i, d in enumerate(dim):
tensor.squeeze_(d-i)
return tensor
def split_feature(tensor, type="split"):
"""
type = ["split", "cross"]
"""
C = tensor.size(1)
if type == "split":
return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
elif type == "cross":
return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
def cat_feature(tensor_a, tensor_b):
return torch.cat((tensor_a, tensor_b), dim=1)
def pixels(tensor):
return int(tensor.size(2) * tensor.size(3))