Add srflow arch
This commit is contained in:
parent
4469d2e661
commit
34d319585c
125
codes/models/archs/srflow/FlowActNorms.py
Normal file
125
codes/models/archs/srflow/FlowActNorms.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.archs.srflow import thops
|
||||
|
||||
|
||||
class _ActNorm(nn.Module):
    """Activation normalization base layer.

    Holds a ``bias`` and a log-scale ``logs`` parameter, both shaped
    ``[1, num_features, 1, 1]``. On the first training-mode forward pass they
    are initialised from that minibatch so that the output is centred and has
    a standard deviation of roughly ``scale``; afterwards both are trained as
    ordinary parameters.

    Subclasses are expected to override ``_check_input_dim``.
    """

    def __init__(self, num_features, scale=1.):
        """
        Args:
            num_features: size of the channel dimension (dim 1) of the input.
            scale: target standard deviation used by the data-dependent init.
        """
        super().__init__()
        # Both parameters broadcast over batch and spatial dims.
        size = [1, num_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        self.scale = float(scale)
        # Plain attribute (not a buffer): it is not stored in the state dict.
        self.inited = False

    def _check_input_dim(self, input):
        """Validate ``input``; the base class performs no check.

        NOTE(review): this returns ``NotImplemented`` instead of raising, so
        calling it on the base class is silently a no-op.
        """
        return NotImplemented

    def initialize_parameters(self, input):
        """Data-dependent initialisation of ``bias`` and ``logs``.

        Does nothing in eval mode. If ``bias`` already holds any non-zero
        value the layer is treated as initialised (for example after weights
        were loaded) and only the ``inited`` flag is set.
        """
        self._check_input_dim(input)
        if not self.training:
            return
        if (self.bias != 0).any():
            self.inited = True
            return
        assert input.device == self.bias.device, (input.device, self.bias.device)
        with torch.no_grad():
            # Statistics are reduced over dims 0, 2 and 3, i.e. everything
            # except the channel dimension.
            bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
            variance = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            # The epsilon keeps the division finite for constant channels.
            logs = torch.log(self.scale / (torch.sqrt(variance) + 1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.inited = True

    def _center(self, input, reverse=False, offset=None):
        """Add (forward) or subtract (reverse) ``bias`` plus optional ``offset``."""
        bias = self.bias
        if offset is not None:
            bias = bias + offset
        if reverse:
            return input - bias
        return input + bias

    def _scale(self, input, logdet=None, reverse=False, offset=None):
        """Multiply by ``exp(logs)`` (forward) or ``exp(-logs)`` (reverse).

        When ``logdet`` is given it is updated by ``sum(logs) * pixels`` with
        the sign flipped in reverse mode; otherwise it is returned unchanged.
        """
        logs = self.logs
        if offset is not None:
            logs = logs + offset

        if reverse:
            input = input * torch.exp(-logs)
        else:
            input = input * torch.exp(logs)

        if logdet is not None:
            # `logs` holds one value per channel, so the contribution is
            # multiplied by the number of pixels reported by thops.pixels.
            dlogdet = thops.sum(logs) * thops.pixels(input)
            if reverse:
                dlogdet *= -1
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
        """Apply actnorm (centre then scale) or its inverse (scale then centre).

        Args:
            input: tensor to normalise.
            logdet: running log-determinant, or ``None`` to skip tracking.
            reverse: run the inverse transform when true.
            offset_mask: optional multiplier applied to both offsets.
            logs_offset: optional additive offset for ``logs``.
            bias_offset: optional additive offset for ``bias``.

        Returns:
            ``(output, logdet)``.
        """
        if not self.inited:
            self.initialize_parameters(input)
        self._check_input_dim(input)

        if offset_mask is not None:
            # Mask out of place: the offsets belong to the caller and must not
            # be modified, and either of them may legitimately be None.
            if logs_offset is not None:
                logs_offset = logs_offset * offset_mask
            if bias_offset is not None:
                bias_offset = bias_offset * offset_mask

        if not reverse:
            # center and scale
            input = self._center(input, reverse, bias_offset)
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
        else:
            # scale and center
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
            input = self._center(input, reverse, bias_offset)
        return input, logdet
|
||||
|
||||
|
||||
class ActNorm2d(_ActNorm):
    """Actnorm for 4-D inputs whose dim 1 has ``num_features`` entries."""

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        """Assert that ``input`` is 4-D with the expected channel count."""
        shape = input.size()
        assert len(shape) == 4
        expected = self.num_features
        assert shape[1] == expected, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                expected, shape))
|
||||
|
||||
|
||||
class MaskedActNorm2d(ActNorm2d):
    """ActNorm2d whose result is only kept where a boolean mask is true."""

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def forward(self, input, mask, logdet=None, reverse=False):
        """Apply actnorm, keeping the original values where ``mask`` is false.

        Args:
            input: tensor to normalise.
            mask: boolean tensor used to index both ``input`` and ``logdet``.
                # NOTE(review): presumably a per-sample mask — confirm the
                # expected shape against the callers.
            logdet: running log-determinant tensor, or ``None`` to skip it.
            reverse: run the inverse transform when true.

        Returns:
            ``(output, logdet)``; ``logdet`` is ``None`` when none was given.
            The caller's ``input`` and ``logdet`` tensors are left untouched.
        """
        assert mask.dtype == torch.bool
        output, logdet_out = super().forward(input, logdet, reverse)

        # Work on copies so the caller's tensors are not modified in place.
        merged = input.clone()
        merged[mask] = output[mask]

        # Without a log-determinant there is nothing to merge; indexing None
        # would raise a TypeError.
        if logdet is None:
            return merged, None

        merged_logdet = logdet.clone()
        merged_logdet[mask] = logdet_out[mask]
        return merged, merged_logdet
|
||||
|
119
codes/models/archs/srflow/FlowAffineCouplingsAblation.py
Normal file
119
codes/models/archs/srflow/FlowAffineCouplingsAblation.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.archs.srflow import thops
|
||||
from models.archs.srflow.flow import Conv2d, Conv2dZeros
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class CondAffineSeparatedAndCond(nn.Module):
    """Two-stage conditional affine layer.

    Stage one ("feature conditional") shifts and scales every channel using
    values predicted from the conditioning features ``ft`` alone. Stage two
    ("self conditional") splits the channels in two halves and shifts/scales
    the second half using values predicted from the first half concatenated
    with ``ft``. The reverse pass undoes the stages in the opposite order.
    """

    def __init__(self, in_channels, opt):
        """
        Args:
            in_channels: channel count of the tensor being transformed.
            opt: options mapping read through ``opt_get`` under
                ``network_G.flow.CondAffineSeparatedAndCond`` for
                ``hidden_channels`` (default 64) and ``eps`` (default 0.0001).
        """
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        # Channel count assumed for the conditioning features `ft`.
        self.in_channels_rrdb = 320
        self.kernel_hidden = 1
        self.n_hidden_layers = 1

        option_root = ['network_G', 'flow', 'CondAffineSeparatedAndCond']
        hidden_channels = opt_get(opt, option_root + ['hidden_channels'])
        self.hidden_channels = 64 if hidden_channels is None else hidden_channels
        # Added to every predicted scale so it stays strictly positive.
        self.affine_eps = opt_get(opt, option_root + ['eps'], 0.0001)

        # First half feeds the network, the remainder gets transformed.
        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn

        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)

        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)

    def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
        """Run the layer forward or in reverse.

        Args:
            input: tensor to transform.
            logdet: running log-determinant; ``None`` disables tracking.
            reverse: run the inverse transform when true.
            ft: conditioning features.

        Returns:
            ``(output, logdet)``.
        """
        z = input
        if not reverse:
            assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z + shiftFt
            z = z * scaleFt
            logdet = self._update_logdet(logdet, scaleFt, subtract=False)

            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 + shift
            z2 = z2 * scale
            logdet = self._update_logdet(logdet, scale, subtract=False)
            z = thops.cat_feature(z1, z2)
        else:
            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 / scale
            z2 = z2 - shift
            z = thops.cat_feature(z1, z2)
            logdet = self._update_logdet(logdet, scale, subtract=True)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z / scaleFt
            z = z - shiftFt
            logdet = self._update_logdet(logdet, scaleFt, subtract=True)

        return z, logdet

    def _update_logdet(self, logdet, scale, subtract):
        """Add or subtract this scale's log-determinant; pass ``None`` through."""
        if logdet is None:
            return None
        delta = self.get_logdet(scale)
        return logdet - delta if subtract else logdet + delta

    def asserts(self, scale, shift, z1, z2):
        """Check that the split halves and predicted parameters line up."""
        assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
        assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
        assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
        assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])

    def get_logdet(self, scale):
        """Return ``log(scale)`` summed over dims 1, 2 and 3."""
        return thops.sum(torch.log(scale), dim=[1, 2, 3])

    def feature_extract(self, z, f):
        """Run ``f`` on ``z`` and turn its output into ``(scale, shift)``.

        The raw scale goes through ``sigmoid(x + 2)`` plus ``affine_eps``.
        """
        h = f(z)
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def feature_extract_aff(self, z1, ft, f):
        """Like ``feature_extract`` but on ``z1`` concatenated with ``ft``."""
        return self.feature_extract(torch.cat([z1, ft], dim=1), f)

    def split(self, z):
        """Split ``z`` along dim 1 into the first ``channels_for_nn`` and the rest."""
        z1 = z[:, :self.channels_for_nn]
        z2 = z[:, self.channels_for_nn:]
        assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
        return z1, z2

    def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
        """Build the prediction network.

        One input ``Conv2d`` + ReLU, ``n_hidden_layers`` hidden ``Conv2d`` +
        ReLU pairs with a ``kernel_hidden`` square kernel, and a final
        ``Conv2dZeros`` producing ``out_channels``.
        """
        layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]

        for _ in range(n_hidden_layers):
            layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
            layers.append(nn.ReLU(inplace=False))
        layers.append(Conv2dZeros(hidden_channels, out_channels))

        return nn.Sequential(*layers)
|
121
codes/models/archs/srflow/FlowStep.py
Normal file
121
codes/models/archs/srflow/FlowStep.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
import models.archs.srflow
|
||||
import models.archs.srflow.Permutations
|
||||
from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def getConditional(rrdbResults, position):
    """Return the conditioning features for ``position``.

    A tensor is returned as-is; anything else is indexed with ``position``.
    """
    if isinstance(rrdbResults, torch.Tensor):
        return rrdbResults
    return rrdbResults[position]
|
||||
|
||||
|
||||
class FlowStep(nn.Module):
    """One flow step: actnorm, then a permutation, then an optional coupling.

    ``forward`` runs the three stages in that order; with ``reverse=True``
    the stages run in the opposite order with their inverse transforms.
    """

    # Maps a permutation name to a callable ``(step, z, logdet, reverse)``
    # returning ``(z, logdet)``. Only "invconv" gets its submodule created in
    # ``__init__``; the other entries rely on attributes set elsewhere.
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
    }

    def __init__(self, in_channels, hidden_channels,
                 actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
                 LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
                 position=None):
        """
        Args:
            in_channels: channel count of the tensors this step transforms.
            hidden_channels: accepted for interface compatibility; unused here.
            actnorm_scale: ``scale`` passed to the actnorm layer.
            flow_permutation: key of ``FlowStep.FlowPermutation``.
            flow_coupling: ``"CondAffineSeparatedAndCond"`` or ``"noCoupling"``.
                Note that the default ``"additive"`` is rejected.
            LU_decomposed: forwarded to ``InvertibleConv1x1``.
            opt: options mapping forwarded to the coupling layer.
            image_injector, acOpt, in_shape: stored on the instance only.
            idx: accepted for interface compatibility; unused here.
            normOpt: optional mapping; its ``'type'`` selects the norm type.
            position: key used to pick features out of ``rrdbResults`` when
                those are not already a tensor.

        Raises:
            AssertionError: unknown ``flow_permutation``.
            RuntimeError: unsupported ``flow_coupling``.
        """
        # check configures (before building anything, so bad options fail fast)
        assert flow_permutation in FlowStep.FlowPermutation, \
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        if flow_coupling not in ("CondAffineSeparatedAndCond", "noCoupling"):
            raise RuntimeError("coupling not Found: {}".format(flow_coupling))
        super().__init__()

        # Imported here so the step uses the srflow package these layers live
        # in, whatever the importing module happened to bind at file level.
        from models.archs.srflow.FlowActNorms import ActNorm2d
        from models.archs.srflow.Permutations import InvertibleConv1x1
        from models.archs.srflow.FlowAffineCouplingsAblation import CondAffineSeparatedAndCond

        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.image_injector = image_injector

        self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'

        self.in_shape = in_shape
        self.position = position
        self.acOpt = acOpt

        # 1. actnorm
        self.actnorm = ActNorm2d(in_channels, actnorm_scale)

        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = InvertibleConv1x1(in_channels, LU_decomposed=LU_decomposed)

        # 3. coupling ("noCoupling" deliberately creates no layer)
        if flow_coupling == "CondAffineSeparatedAndCond":
            self.affine = CondAffineSeparatedAndCond(in_channels=in_channels, opt=opt)

    def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
        """Dispatch to ``normal_flow`` or ``reverse_flow``; returns ``(z, logdet)``."""
        if reverse:
            return self.reverse_flow(input, logdet, rrdbResults)
        return self.normal_flow(input, logdet, rrdbResults)

    def _coupling_uses_features(self):
        """True when the coupling stage has to be run with conditioning features."""
        return self.affine_need_features() or \
            self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]

    def normal_flow(self, z, logdet, rrdbResults=None):
        """Forward direction: actnorm -> permutation -> coupling."""
        # NOTE(review): __init__ rejects this coupling and never defines
        # `bentIdentPar`, so this branch looks unreachable — confirm.
        if self.flow_coupling == "bentIdentityPreAct":
            z, logdet = self.bentIdentPar(z, logdet, reverse=False)

        # 1. actnorm
        if self.norm_type == "ConditionalActNormImageInjector":
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
        elif self.norm_type == "noNorm":
            pass
        else:
            z, logdet = self.actnorm(z, logdet=logdet, reverse=False)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)

        # 3. coupling
        if self._coupling_uses_features():
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
        return z, logdet

    def reverse_flow(self, z, logdet, rrdbResults=None):
        """Reverse direction: coupling -> permutation -> actnorm."""
        # 1. coupling
        if self._coupling_uses_features():
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)

        # 3. actnorm — skipped for "noNorm" exactly as in normal_flow,
        # otherwise the reverse pass would undo a transform never applied.
        if self.norm_type != "noNorm":
            z, logdet = self.actnorm(z, logdet=logdet, reverse=True)

        return z, logdet

    def affine_need_features(self):
        """Return the coupling layer's ``need_features`` flag, or False if absent."""
        return getattr(getattr(self, "affine", None), "need_features", False)
|
293
codes/models/archs/srflow/FlowUpsamplerNet.py
Normal file
293
codes/models/archs/srflow/FlowUpsamplerNet.py
Normal file
|
@ -0,0 +1,293 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
import models.archs.srflow.Split
|
||||
from models.archs.srflow import flow, thops
|
||||
from models.archs.srflow.Split import Split2d
|
||||
from models.archs.srflow.glow_arch import f_conv2d_bias
|
||||
from models.archs.srflow.FlowStep import FlowStep
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class FlowUpsamplerNet(nn.Module):
    """Multi-level stack of squeeze, flow-step and split layers.

    ``encode`` runs the ground-truth image through ``self.layers`` in order;
    ``decode`` runs a latent through them in reverse. The level a layer
    belongs to is derived from the height stored in ``self.output_shapes``
    relative to a reference size of 160, and ``self.levelToName`` maps that
    level to the key used to pick conditioning features from ``rrdbResults``.
    """

    # Feature-dictionary key per level, for each supported `opt['scale']`.
    _LEVEL_NAMES_BY_SCALE = {
        16: {0: 'fea_up16', 1: 'fea_up8', 2: 'fea_up4', 3: 'fea_up2', 4: 'fea_up1'},
        8: {0: 'fea_up8', 1: 'fea_up4', 2: 'fea_up2', 3: 'fea_up1', 4: 'fea_up0'},
        4: {0: 'fea_up4', 1: 'fea_up2', 2: 'fea_up1', 3: 'fea_up0', 4: 'fea_up-1'},
    }

    def __init__(self, image_shape, hidden_channels, K, L=None,
                 actnorm_scale=1.0,
                 flow_permutation=None,
                 flow_coupling="affine",
                 LU_decomposed=False, opt=None):
        """
        Args:
            image_shape: ``(H, W, C)`` of the high-resolution image.
            hidden_channels: forwarded to every ``FlowStep``.
            K: flow steps per level; only used when ``network_G.flow.K`` in
                ``opt`` is a single int (a list there is taken as-is).
            L: unused; the level count is read from ``network_G.flow.L``.
            actnorm_scale: forwarded to every ``FlowStep``.
            flow_permutation: ignored; see ``get_flow_permutation``.
            flow_coupling: forwarded to the main ``FlowStep`` layers.
            LU_decomposed: forwarded to every ``FlowStep``.
            opt: options mapping; ``opt['scale']`` must be present.
        """
        super().__init__()

        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.L = opt_get(opt, ['network_G', 'flow', 'L'])
        self.K = opt_get(opt, ['network_G', 'flow', 'K'])
        if isinstance(self.K, int):
            # One entry per level index 0..L.
            self.K = [K] * (self.L + 1)

        self.opt = opt
        H, W, self.C = image_shape
        self.check_image_shape()

        # For scales outside the table the attribute is left unset.
        level_names = self._LEVEL_NAMES_BY_SCALE.get(opt['scale'])
        if level_names is not None:
            self.levelToName = dict(level_names)

        affineInCh = self.get_affineInCh(opt_get)
        flow_permutation = self.get_flow_permutation(flow_permutation, opt)

        normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])

        n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
        n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
        conditional_channels = {0: n_rrdb}
        for level in range(1, self.L + 1):
            # Level 1 gets conditionals from 2, 3, 4 => L - level
            # Level 2 gets conditionals from 3, 4
            # Level 3 gets conditionals from 4
            # Level 4 gets conditionals from None
            n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
            conditional_channels[level] = n_rrdb + n_bypass

        # Upsampler
        for level in range(1, self.L + 1):
            # 1. Squeeze
            H, W = self.arch_squeeze(H, W)

            # 2. K FlowStep
            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
            self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                               flow_permutation,
                               hidden_channels, normOpt, opt, opt_get,
                               n_conditinal_channels=conditional_channels[level])
            # Split
            self.arch_split(H, W, level, self.L, opt, opt_get)

        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
        else:
            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)

        self.H = H
        self.W = W
        self.scaleH = 160 / H
        self.scaleW = 160 / W

    def get_n_rrdb_channels(self, opt, opt_get):
        """Return 64 channels per stacked block listed in the options, plus 64."""
        blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
        n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
        return n_rrdb

    def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
                      hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
        """Append ``K`` flow steps (and their output shapes) for the current level.

        Side effects: writes ``in_channels_rrdb`` into the conditional-affine
        options and ``position`` into ``normOpt`` when those mappings exist.
        """
        condAff = self.get_condAffSetting(opt, opt_get)
        if condAff is not None:
            condAff['in_channels_rrdb'] = n_conditinal_channels

        for k in range(K):
            position_name = get_position_name(H, self.opt['scale'])
            if normOpt:
                normOpt['position'] = position_name

            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         acOpt=condAff,
                         position=position_name,
                         LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def get_condAffSetting(self, opt, opt_get):
        """Return the ``condFtAffine`` options if truthy, else ``condAff``, else None."""
        condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
        condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
        return condAff

    def arch_split(self, H, W, L, levels, opt, opt_get):
        """Append a ``Split2d`` after the current level when splitting is enabled.

        Here ``L`` is the current level and ``levels`` the total level count.
        A split is added only while ``L < levels - correction``, where the
        correction is 0 if ``split.correct_splits`` is set and 1 otherwise.
        """
        correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
        correction = 0 if correct_splits else 1
        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
            logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
            consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
            position_name = get_position_name(H, self.opt['scale'])
            position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
            cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
            cond_channels = 0 if cond_channels is None else cond_channels

            t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')

            if t == 'Split2d':
                # Use the class imported by this module.
                split = Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            # Later layers only see the channels that were passed on.
            self.C = split.num_channels_pass

    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
        """Append the configured number of coupling-free ``invconv`` flow steps."""
        if 'additionalFlowNoAffine' in opt['network_G']['flow']:
            n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
            for _ in range(n_additionalFlowNoAffine):
                self.layers.append(
                    FlowStep(in_channels=self.C,
                             hidden_channels=hidden_channels,
                             actnorm_scale=actnorm_scale,
                             flow_permutation='invconv',
                             flow_coupling='noCoupling',
                             LU_decomposed=LU_decomposed, opt=opt))
                self.output_shapes.append(
                    [-1, self.C, H, W])

    def arch_squeeze(self, H, W):
        """Append a factor-2 squeeze: channels x4, height and width halved."""
        self.C, H, W = self.C * 4, H // 2, W // 2
        self.layers.append(flow.SqueezeLayer(factor=2))
        self.output_shapes.append([-1, self.C, H, W])
        return H, W

    def get_flow_permutation(self, flow_permutation, opt):
        """Return ``network_G.flow.flow_permutation`` from ``opt`` (default 'invconv').

        The ``flow_permutation`` argument is ignored.
        """
        flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
        return flow_permutation

    def get_affineInCh(self, opt_get):
        """Return 64 channels per stacked block listed in ``self.opt``, plus 64."""
        affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        affineInCh = (len(affineInCh) + 1) * 64
        return affineInCh

    def check_image_shape(self):
        """Assert that the channel count taken from ``image_shape`` is 1 or 3."""
        assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3)"
                                            "self.C == 1 or self.C == 3")

    def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
                y_onehot=None):
        """Encode ``gt`` (default) or decode ``z`` / ``epses`` (``reverse=True``).

        Returns:
            ``(z, logdet)`` when encoding, ``(sr, logdet)`` when decoding.
        """
        if reverse:
            # decode() pops from the list, so hand it a shallow copy.
            epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses

            sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
            return sr, logdet
        else:
            assert gt is not None
            assert rrdbResults is not None
            z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)

            return z, logdet

    @staticmethod
    def _level_for_shape(shape):
        """Level index of a layer, from its output height relative to 160."""
        return int(np.log(160 / shape[2]) / np.log(2))

    def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
        """Run ``gt`` through every layer in order.

        Returns ``(z, logdet)``; when ``epses`` is a list the final latent is
        appended to it and ``(epses, logdet)`` is returned instead.
        """
        fl_fea = gt
        reverse = False
        level_conditionals = {}

        for layer, shape in zip(self.layers, self.output_shapes):
            level = self._level_for_shape(shape)
            level_conditionals[level] = rrdbResults[self.levelToName[level]]

            if isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
            elif isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
                                                      y_onehot=y_onehot)
            else:
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)

        z = fl_fea

        if not isinstance(epses, list):
            return z, logdet

        epses.append(z)
        return epses, logdet

    def forward_preFlow(self, fl_fea, logdet, reverse):
        """Run the optional ``preFlow`` layers, if that attribute exists."""
        if hasattr(self, 'preFlow'):
            for l in self.preFlow:
                fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
        return fl_fea, logdet

    def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
        """Run a split layer forward and collect its eps when ``epses`` is a list."""
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)

        if isinstance(epses, list):
            epses.append(eps)
        return fl_fea, logdet

    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
        """Run a latent through the layers in reverse order.

        When ``epses`` is a list, its last element is used as the starting
        latent and the remaining ones are consumed by the split layers.

        Returns:
            ``(sr, logdet)``; asserts that ``sr`` has 3 channels.
        """
        z = epses.pop() if isinstance(epses, list) else z

        fl_fea = z
        level_conditionals = {}
        if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
            for level in range(self.L + 1):
                level_conditionals[level] = rrdbResults[self.levelToName[level]]

        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            level = self._level_for_shape(shape)

            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)

        sr = fl_fea

        assert sr.shape[1] == 3
        return sr, logdet

    def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
        """Run a split layer in reverse, feeding it the next stored eps if any."""
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                               eps=epses.pop() if isinstance(epses, list) else None,
                               eps_std=eps_std, ft=ft, y_onehot=y_onehot)
        return fl_fea, logdet
|
||||
|
||||
|
||||
def get_position_name(H, scale):
    """Build the feature key for a layer of height ``H``.

    The name is ``'fea_up'`` followed by ``scale`` divided by ``160 // H``.
    NOTE(review): the true division yields a float, so names come out as
    e.g. ``'fea_up2.0'`` — confirm this is what consumers index with.
    """
    return 'fea_up{}'.format(scale / (160 // H))
|
42
codes/models/archs/srflow/Permutations.py
Normal file
42
codes/models/archs/srflow/Permutations.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.modules import thops
|
||||
|
||||
|
||||
class InvertibleConv1x1(nn.Module):
    """1x1 convolution with a square, learnable weight matrix.

    The weight starts as the Q factor of a QR decomposition of a random
    matrix. ``LU_decomposed`` is only stored on the instance; the code below
    always works with the full weight matrix.
    """

    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        shape = [num_channels, num_channels]
        q_factor = np.linalg.qr(np.random.randn(*shape))[0]
        self.register_parameter("weight", nn.Parameter(torch.Tensor(q_factor.astype(np.float32))))
        self.w_shape = shape
        self.LU = LU_decomposed

    def get_weight(self, input, reverse):
        """Return the conv kernel for the requested direction and the log-det term.

        The log-det term is ``log|det(W)|`` times ``thops.pixels(input)`` and
        is the same for both directions; the caller applies the sign.
        """
        rows, cols = self.w_shape[0], self.w_shape[1]
        n_pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * n_pixels
        if reverse:
            # Invert in double precision, then go back to float32.
            matrix = torch.inverse(self.weight.double()).float()
        else:
            matrix = self.weight
        return matrix.view(rows, cols, 1, 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """
        log-det = log|abs(|W|)| * pixels

        Returns ``(z, logdet)``; ``logdet`` stays ``None`` when none is given,
        is increased in the forward direction and decreased in reverse.
        """
        weight, dlogdet = self.get_weight(input, reverse)
        z = F.conv2d(input, weight)
        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet
|
132
codes/models/archs/srflow/RRDBNet_arch.py
Normal file
132
codes/models/archs/srflow/RRDBNet_arch.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import models.archs.srflow.module_util as mutil
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Each of the first four convolutions sees the block input concatenated
    with all earlier outputs; the fifth maps the full concatenation back to
    ``nf`` channels, which is scaled by 0.2 and added to the input.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        """
        Args:
            nf: channel count of the block input and output.
            gc: growth channels, i.e. channels produced by convs 1-4.
            bias: whether the convolutions use a bias term.
        """
        super(ResidualDenseBlock_5C, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv1 = nn.Conv2d(nf, gc, **conv_kwargs)
        self.conv2 = nn.Conv2d(nf + gc, gc, **conv_kwargs)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, **conv_kwargs)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, **conv_kwargs)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, **conv_kwargs)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        mutil.initialize_weights(convs, 0.1)

    def forward(self, x):
        """Return ``conv5(dense features) * 0.2 + x``."""
        dense = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            grown = self.lrelu(conv(dense))
            # Grow the running concatenation along the channel dimension.
            dense = torch.cat((dense, grown), 1)
        return self.conv5(dense) * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
    '''Residual in Residual Dense Block: three dense blocks plus a scaled skip.'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        """Chain the three dense blocks and return ``out * 0.2 + x``."""
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
    """RRDB trunk followed by nearest-neighbour x2 upsampling stages.

    Besides the final image, ``forward`` can return the intermediate feature
    maps of every upsampling stage, keyed ``'fea_up1'``, ``'fea_up2'``, ...
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
        """
        Args:
            in_nc: input image channels.
            out_nc: output image channels.
            nf: feature channels used throughout the network.
            nb: number of RRDB blocks in the trunk.
            gc: growth channels inside each dense block.
            scale: upsampling factor; stages beyond x4 are only created for
                ``scale >= 8``, ``>= 16`` and ``>= 32``.
            opt: options mapping read in ``forward`` through ``opt_get``.
        """
        # Module.__init__ must run before any attribute is assigned.
        super(RRDBNet, self).__init__()
        self.opt = opt
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.scale = scale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 8:
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 16:
            self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 32:
            self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, get_steps=False):
        """Run the network.

        Args:
            x: input image batch.
            get_steps: when true, return a dict of intermediate features
                (plus the outputs of the trunk blocks listed under
                ``network_G.flow.stackRRDB.blocks``) instead of the image.

        Returns:
            The output image, or the feature dict when ``get_steps`` is true.
            Stages that do not exist for this ``scale`` are ``None`` in it.
        """
        fea = self.conv_first(x)

        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        block_results = {}

        for idx, m in enumerate(self.RRDB_trunk.children()):
            fea = m(fea)
            # Keep the output of every trunk block requested in the options.
            if idx in block_idxs:
                block_results["block_{}".format(idx)] = fea

        trunk = self.trunk_conv(fea)

        last_lr_fea = fea + trunk

        fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up2)

        fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up4)

        fea_up8 = None
        fea_up16 = None
        fea_up32 = None

        if self.scale >= 8:
            fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up8)
        if self.scale >= 16:
            fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up16)
        if self.scale >= 32:
            fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up32)

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        results = {'last_lr_fea': last_lr_fea,
                   'fea_up1': last_lr_fea,
                   'fea_up2': fea_up2,
                   'fea_up4': fea_up4,
                   'fea_up8': fea_up8,
                   'fea_up16': fea_up16,
                   'fea_up32': fea_up32,
                   'out': out}

        # Optional bilinearly downsampled copies of the low-resolution features.
        fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
        if fea_up0_en:
            results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
        if fea_upn1_en:
            results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)

        if get_steps:
            results.update(block_results)
            return results
        else:
            return out
|
69
codes/models/archs/srflow/Split.py
Normal file
69
codes/models/archs/srflow/Split.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from models.archs.srflow import thops
|
||||
from models.archs.srflow.FlowStep import FlowStep
|
||||
from models.archs.srflow.flow import Conv2dZeros, GaussianDiag
|
||||
|
||||
|
||||
class Split2d(nn.Module):
    """Flow split layer: keep some channels, model the rest as a Gaussian.

    The first ``num_channels_pass`` channels of the input are passed on
    unchanged. The remaining ``num_channels_consume`` channels are scored
    under (forward) or generated from (reverse) a diagonal Gaussian whose
    mean and log-scale are predicted by a ``Conv2dZeros`` from the passed
    channels, optionally concatenated with conditioning features ``ft``.
    """

    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
        super().__init__()

        consumed = int(round(num_channels * consume_ratio))
        self.num_channels_consume = consumed
        self.num_channels_pass = num_channels - consumed

        # Two outputs (mean and log-scale) per consumed channel.
        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=consumed * 2)
        self.logs_eps = logs_eps
        self.position = position
        self.opt = opt

    def split2d_prior(self, z, ft):
        """Predict ``(mean, logs)`` of the prior from ``z`` and optional ``ft``."""
        prior_input = z if ft is None else torch.cat([z, ft], dim=1)
        return thops.split_feature(self.conv(prior_input), "cross")

    def exp_eps(self, logs):
        """Return the scale ``exp(logs)`` offset by ``self.logs_eps``."""
        return torch.exp(logs) + self.logs_eps

    def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
        """Run the split.

        Forward (``reverse=False``) returns ``(z1, logdet, eps)`` where ``eps``
        is the consumed part normalised by the predicted mean and scale.
        Reverse returns ``(z, logdet)`` where ``z`` is ``input`` concatenated
        with channels rebuilt from ``eps`` (sampled via
        ``GaussianDiag.sample_eps`` when ``eps`` is None).
        ``y_onehot`` is accepted but not used here.
        """
        if reverse:
            z1 = input
            mean, logs = self.split2d_prior(z1, ft)

            if eps is None:
                eps = GaussianDiag.sample_eps(mean.shape, eps_std)

            eps = eps.to(mean.device)
            z2 = mean + self.exp_eps(logs) * eps

            z = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(logs, mean, z2)
            return z, logdet

        z1, z2 = self.split_ratio(input)
        mean, logs = self.split2d_prior(z1, ft)

        eps = (z2 - mean) / self.exp_eps(logs)
        logdet = logdet + self.get_logdet(logs, mean, z2)
        return z1, logdet, eps

    def get_logdet(self, logs, mean, z2):
        """Return the Gaussian log-density of ``z2`` summed per sample."""
        return GaussianDiag.logp(mean, logs, z2)

    def split_ratio(self, input):
        """Slice ``input`` along channels into (passed, consumed) parts."""
        boundary = self.num_channels_pass
        return input[:, :boundary, ...], input[:, boundary:, ...]
|
0
codes/models/archs/srflow/__init__.py
Normal file
0
codes/models/archs/srflow/__init__.py
Normal file
150
codes/models/archs/srflow/flow.py
Normal file
150
codes/models/archs/srflow/flow.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from models.archs.srflow.FlowActNorms import ActNorm2d
|
||||
from . import thops
|
||||
|
||||
|
||||
class Conv2d(nn.Conv2d):
    """Convolution with string padding modes and optional ActNorm.

    When ``do_actnorm`` is true the convolution has no bias and its output is
    passed through an ``ActNorm2d``; otherwise a zero-initialised bias is
    used. Weights are drawn from a normal distribution with standard
    deviation ``weight_std``.
    """

    # Padding calculators keyed by (lower-cased) mode name.
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel],
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        """Resolve ``padding`` to explicit per-dimension values.

        Args:
            padding: ``"same"``/``"valid"`` (case-insensitive), or any
                non-string value, which is returned unchanged.
            kernel_size: int or pair of ints.
            stride: int or pair of ints.

        Returns:
            A list of paddings for string modes, else ``padding`` itself.

        Raises:
            ValueError: if ``padding`` is a string that is not a known mode.
        """
        if not isinstance(padding, str):
            return padding
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]
        if isinstance(stride, int):
            stride = [stride, stride]
        mode = padding.lower()
        try:
            return Conv2d.pad_dict[mode](kernel_size, stride)
        except KeyError:
            # The KeyError carries no extra information for the caller.
            raise ValueError("{} is not supported".format(mode)) from None

    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1),
                 padding="same", do_actnorm=True, weight_std=0.05):
        # Tuple defaults avoid sharing one mutable list across all calls.
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, bias=(not do_actnorm))
        # init weight with std
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if not do_actnorm:
            self.bias.data.zero_()
        else:
            self.actnorm = ActNorm2d(out_channels)
        self.do_actnorm = do_actnorm

    def forward(self, input):
        """Apply the convolution, then ActNorm when enabled."""
        x = super().forward(input)
        if self.do_actnorm:
            # ActNorm returns a pair; only its first element is used here.
            x, _ = self.actnorm(x)
        return x
|
||||
|
||||
|
||||
class Conv2dZeros(nn.Conv2d):
    """Convolution initialised to zero with a learned per-channel output scale.

    Weight and bias start at zero, and the output is multiplied by
    ``exp(logs * logscale_factor)`` where ``logs`` is a parameter of shape
    ``(out_channels, 1, 1)`` that also starts at zero.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", logscale_factor=3):
        resolved_padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, resolved_padding)
        # Multiplier applied to ``logs`` before exponentiation.
        self.logscale_factor = logscale_factor
        self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))
        # Zero initialisation of the convolution itself.
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        """Convolve ``input`` and rescale each output channel."""
        convolved = super().forward(input)
        channel_scale = torch.exp(self.logs * self.logscale_factor)
        return convolved * channel_scale
|
||||
|
||||
|
||||
class GaussianDiag:
    """Helpers for a diagonal Gaussian parameterised by mean and log-std."""

    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(mean, logs, x):
        """Element-wise Gaussian log-density of ``x``.

        lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
              k = 1 (Independent)
              Var = exp(2 * logs), i.e. ``logs`` is the log standard deviation

        When both ``mean`` and ``logs`` are None a standard normal is used.
        """
        if mean is None and logs is None:
            return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
        return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)

    @staticmethod
    def logp(mean, logs, x):
        """Log-density of ``x`` summed over dims 1, 2 and 3."""
        likelihood = GaussianDiag.likelihood(mean, logs, x)
        return thops.sum(likelihood, dim=[1, 2, 3])

    @staticmethod
    def sample(mean, logs, eps_std=None):
        """Draw ``mean + exp(logs) * eps`` with ``eps ~ N(0, eps_std^2)``.

        NOTE: any falsy ``eps_std`` (None, and also 0) falls back to 1.
        """
        eps_std = eps_std or 1
        eps = torch.normal(mean=torch.zeros_like(mean),
                           std=torch.ones_like(logs) * eps_std)
        return mean + torch.exp(logs) * eps

    @staticmethod
    def sample_eps(shape, eps_std, seed=None):
        """Draw a noise tensor of ``shape`` from ``N(0, eps_std^2)``.

        Args:
            shape: shape of the returned tensor.
            eps_std: standard deviation; None means 1, matching ``sample``.
            seed: when not None, ``torch.manual_seed(seed)`` is called first,
                which reseeds the global torch generator.

        Returns:
            A tensor created with torch's default device and dtype.
        """
        if seed is not None:
            torch.manual_seed(seed)
        if eps_std is None:
            # Callers pass their own default of None straight through;
            # multiplying by None below would raise a TypeError.
            eps_std = 1
        eps = torch.normal(mean=torch.zeros(shape),
                           std=torch.ones(shape) * eps_std)
        return eps
|
||||
|
||||
|
||||
def squeeze2d(input, factor=2):
    """Fold ``factor x factor`` spatial blocks into the channel dimension.

    A ``(B, C, H, W)`` input becomes ``(B, C * factor**2, H // factor,
    W // factor)``. ``factor == 1`` returns ``input`` itself. ``H`` and ``W``
    must be divisible by ``factor``.
    """
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    B, C, H, W = input.size()[:4]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
    out_h, out_w = H // factor, W // factor
    # Expose the in-block offsets as their own axes ...
    blocks = input.view(B, C, out_h, factor, out_w, factor)
    # ... and move them next to the channel axis before merging.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
    return blocks.view(B, C * factor * factor, out_h, out_w)
|
||||
|
||||
|
||||
def unsqueeze2d(input, factor=2):
    """Unfold channels back into ``factor x factor`` spatial blocks.

    A ``(B, C, H, W)`` input becomes ``(B, C // factor**2, H * factor,
    W * factor)``. ``factor == 1`` returns ``input`` itself. ``C`` must be
    divisible by ``factor ** 2``.
    """
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    B, C, H, W = input.size()[:4]
    assert C % (factor2) == 0, "{}".format(C)
    out_c = C // factor2
    # Split the channel axis into (channel, row offset, column offset) ...
    blocks = input.view(B, out_c, factor, factor, H, W)
    # ... and interleave the offsets with the spatial axes.
    blocks = blocks.permute(0, 1, 4, 2, 5, 3).contiguous()
    return blocks.view(B, out_c, H * factor, W * factor)
|
||||
|
||||
|
||||
class SqueezeLayer(nn.Module):
    """Module wrapper around ``squeeze2d`` / ``unsqueeze2d``."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        """Squeeze (forward) or unsqueeze (reverse); ``logdet`` passes through unchanged."""
        transform = unsqueeze2d if reverse else squeeze2d
        return transform(input, self.factor), logdet
|
12
codes/models/archs/srflow/glow_arch.py
Normal file
12
codes/models/archs/srflow/glow_arch.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
def f_conv2d_bias(in_channels, out_channels):
    """Return a one-layer ``nn.Sequential`` holding a biased 3x3, stride-1, padding-1 convolution."""
    kernel, stride = [3, 3], [1, 1]
    # Sanity check that "same" padding for this kernel/stride is the
    # hard-coded padding used below.
    same_padding = [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]
    assert same_padding == [1, 1], same_padding

    conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=[3, 3], stride=1, padding=1, bias=True)
    return nn.Sequential(conv)
|
79
codes/models/archs/srflow/module_util.py
Normal file
79
codes/models/archs/srflow/module_util.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
    """Re-initialise the weights of one module or a list of modules in place.

    ``nn.Conv2d`` and ``nn.Linear`` weights get Kaiming-normal (fan-in)
    initialisation multiplied by ``scale``, and their biases (if any) are
    zeroed. ``nn.BatchNorm2d`` weights are set to 1 and biases to 0.
    """
    networks = net_l if isinstance(net_l, list) else [net_l]
    for network in networks:
        for module in network.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                module.weight.data *= scale  # for residual block
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
    """Build an ``nn.Sequential`` of ``n_layers`` modules, each from a fresh ``block()`` call."""
    return nn.Sequential(*[block() for _ in range(n_layers)])
|
||||
|
||||
|
||||
class ResidualBlock_noBN(nn.Module):
    '''Residual block without batch normalisation.

    ---Conv-ReLU-Conv-+-
     |________________|

    Both convolutions are 3x3, stride 1, padding 1, with ``nf`` input and
    output channels, and are initialised via ``initialize_weights`` with a
    scale of 0.1.
    '''

    def __init__(self, nf=64):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # Scaled-down initialisation of the residual branch.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the Conv-ReLU-Conv branch applied to ``x``."""
        branch = self.conv2(F.relu(self.conv1(x), inplace=True))
        return x + branch
|
||||
|
||||
|
||||
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), normal value; the last dimension
            holds the (x, y) displacement in pixels
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): forwarded to ``F.grid_sample``. The grid below
            is normalised with ``(W - 1)`` / ``(H - 1)``, which is the
            ``align_corners=True`` convention, so True is the default;
            with it a zero flow reproduces ``x``.

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    # Base sampling grid of pixel coordinates, shape (H, W, 2) as (x, y).
    grid_y = torch.arange(0, H, device=x.device).view(H, 1).expand(H, W)
    grid_x = torch.arange(0, W, device=x.device).view(1, W).expand(H, W)
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)
    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode,
                         align_corners=align_corners)
|
141
codes/models/archs/srflow/srflow_arch.py
Normal file
141
codes/models/archs/srflow/srflow_arch.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.archs.srflow.thops as thops
|
||||
import models.archs.srflow.flow as flow
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class SRFlowNet(nn.Module):
    """Super-resolution flow: an RRDB feature encoder feeding a conditional flow.

    ``forward`` dispatches to ``normal_flow`` (encode ``gt`` to latents and
    return the NLL) or to ``reverse_flow`` (decode latents to an image),
    both conditioned on features computed from ``lr``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
        super().__init__()

        self.opt = opt
        quant = opt_get(opt, ['datasets', 'train', 'quant'])
        self.quant = 255 if quant is None else quant
        # NOTE(review): RRDBNet is not among this file's visible imports - confirm it is in scope.
        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) or 64
        self.RRDB_training = True  # Default is true

        # Looked up but currently unused in this constructor.
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])

        self.flowUpsamplerNet = FlowUpsamplerNet(
            (160, 160, 3), hidden_channels, K,
            flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
        self.i = 0

    def set_rrdb_training(self, trainable):
        """Set ``requires_grad`` on all RRDB parameters; return True if the state changed."""
        if self.RRDB_training == trainable:
            return False
        for param in self.RRDB.parameters():
            param.requires_grad = trainable
        self.RRDB_training = trainable
        return True

    def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
                lr_enc=None,
                add_gt_noise=False, step=None, y_label=None):
        """Encode ``gt`` (default) or, with ``reverse=True``, decode ``z``.

        The reverse pass runs under ``torch.no_grad()`` unless
        ``reverse_with_grad`` is true, and requires ``lr`` to have 3 channels.
        """
        if not reverse:
            return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                    y_onehot=y_label)

        assert lr.shape[1] == 3
        reverse_kwargs = dict(y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                              add_gt_noise=add_gt_noise)
        if reverse_with_grad:
            return self.reverse_flow(lr, z, **reverse_kwargs)
        with torch.no_grad():
            return self.reverse_flow(lr, z, **reverse_kwargs)

    def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
        """Encode ``gt`` and return ``(latents, nll, logdet)``.

        ``nll`` is the negative objective divided by ``log(2) * pixels``.
        The first element is the full ``epses`` when the flow returns a list,
        otherwise the final latent ``z``.
        """
        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        logdet = torch.zeros_like(gt[:, 0, 0, 0])
        pixels = thops.pixels(gt)

        z = gt

        if add_gt_noise:
            # Uniform noise of width 1/quant, enabled unless the option says otherwise.
            noise_enabled = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
            if noise_enabled:
                z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
            logdet = logdet + float(-np.log(self.quant) * pixels)

        # Encode
        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
                                              y_onehot=y_onehot)

        z = epses[-1] if isinstance(epses, (list, tuple)) else epses

        # Log-determinant plus the standard-normal log-density of the final latent.
        objective = logdet.clone() + flow.GaussianDiag.logp(None, None, z)
        nll = (-objective) / float(np.log(2.) * pixels)

        if isinstance(epses, list):
            return epses, nll, logdet
        return z, nll, logdet

    def rrdbPreprocessing(self, lr):
        """Run the RRDB encoder on ``lr`` and optionally append stacked block features.

        When the ``stackRRDB`` ``blocks`` option is non-empty and its
        ``concat`` option is truthy, the selected block outputs are
        concatenated, resized to each feature map's spatial size and
        appended to that map along the channel dimension.
        """
        rrdbResults = self.RRDB(lr, get_steps=True)
        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        if len(block_idxs) == 0:
            return rrdbResults

        concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)

        if not (opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False):
            return rrdbResults

        keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
        for optional_key in ('fea_up0', 'fea_up-1'):
            if optional_key in rrdbResults.keys():
                keys.append(optional_key)
        if self.opt['scale'] >= 8:
            keys.append('fea_up8')
        if self.opt['scale'] == 16:
            keys.append('fea_up16')

        for key in keys:
            feature = rrdbResults[key]
            h, w = feature.shape[2], feature.shape[3]
            rrdbResults[key] = torch.cat([feature, F.interpolate(concat, (h, w))], dim=1)
        return rrdbResults

    def get_score(self, disc_loss_sigma, z):
        """Return the negated ``score_real`` computed from ``z`` and ``disc_loss_sigma``."""
        n_elements = z.shape[1] * z.shape[2] * z.shape[3]
        squared_norm = thops.sum(z ** 2, dim=[1, 2, 3])
        score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * squared_norm - \
            n_elements * math.log(disc_loss_sigma)
        return -score_real

    def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
        """Decode ``z`` (or ``epses``) conditioned on ``lr`` and return ``(x, logdet)``.

        NOTE(review): ``y_onehot`` is accepted but not forwarded to the flow.
        """
        logdet = torch.zeros_like(lr[:, 0, 0, 0])
        pixels = thops.pixels(lr) * self.opt['scale'] ** 2

        if add_gt_noise:
            logdet = logdet - float(-np.log(self.quant) * pixels)

        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
                                          logdet=logdet)

        return x, logdet
|
52
codes/models/archs/srflow/thops.py
Normal file
52
codes/models/archs/srflow/thops.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
|
||||
|
||||
def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over one or several dimensions.

    Args:
        tensor: input tensor.
        dim: None (sum everything to a scalar tensor), an int, or a
            sequence of ints. Negative indices are supported. An empty
            sequence returns ``tensor`` unchanged.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        # sum up all dim
        return torch.sum(tensor)
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    if not dims:
        return tensor
    # The native multi-dim reduction handles negative indices, which
    # squeezing the reduced axes one by one does not.
    return torch.sum(tensor, dim=dims, keepdim=keepdim)
|
||||
|
||||
|
||||
def mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over one or several dimensions.

    Args:
        tensor: input tensor.
        dim: None (mean of everything as a scalar tensor), an int, or a
            sequence of ints. Negative indices are supported. An empty
            sequence returns ``tensor`` unchanged.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        # mean all dim
        return torch.mean(tensor)
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    if not dims:
        return tensor
    # The native multi-dim reduction handles negative indices, which
    # squeezing the reduced axes one by one does not.
    return torch.mean(tensor, dim=dims, keepdim=keepdim)
|
||||
|
||||
|
||||
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along the channel dimension (dim 1).

    Args:
        tensor: input with channels on dim 1.
        type: ``"split"`` returns the first and second half of the
            channels; ``"cross"`` returns the even-indexed and odd-indexed
            channels.

    Returns:
        A pair of tensors (slices of ``tensor``).

    Raises:
        ValueError: if ``type`` is neither ``"split"`` nor ``"cross"``.
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Falling through would return None and fail later when unpacked.
    raise ValueError("unsupported split type: {!r}".format(type))
|
||||
|
||||
|
||||
def cat_feature(tensor_a, tensor_b):
    """Concatenate two tensors along the channel dimension (dim 1)."""
    parts = (tensor_a, tensor_b)
    return torch.cat(parts, dim=1)
|
||||
|
||||
|
||||
def pixels(tensor):
    """Return the product of the sizes of dims 2 and 3 of ``tensor`` as an int."""
    height, width = tensor.size(2), tensor.size(3)
    return int(height * width)
|
Loading…
Reference in New Issue
Block a user