forked from mrq/DL-Art-School
Remove srflow (modified version)
Starting from orig and re-working from there.
This commit is contained in:
parent 5f5420ff4a
commit 6de4dabb73
@@ -1,125 +0,0 @@
import torch
from torch import nn as nn

from models.archs.srflow import thops


class _ActNorm(nn.Module):
    """
    Activation Normalization

    Initialize the bias and scale with a given minibatch,
    so that the output per-channel has zero mean and unit variance for that minibatch.

    After initialization, `bias` and `logs` will be trained as parameters.
    """

    def __init__(self, num_features, scale=1.):
        super().__init__()
        # register mean and scale
        size = [1, num_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        self.scale = float(scale)
        self.inited = False

    def _check_input_dim(self, input):
        return NotImplemented

    def initialize_parameters(self, input):
        self._check_input_dim(input)
        if not self.training:
            return
        if (self.bias != 0).any():
            self.inited = True
            return
        assert input.device == self.bias.device, (input.device, self.bias.device)
        with torch.no_grad():
            bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
            vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.inited = True

    def _center(self, input, reverse=False, offset=None):
        bias = self.bias

        if offset is not None:
            bias = bias + offset

        if not reverse:
            return input + bias
        else:
            return input - bias

    def _scale(self, input, logdet=None, reverse=False, offset=None):
        logs = self.logs

        if offset is not None:
            logs = logs + offset

        if not reverse:
            input = input * torch.exp(logs)  # should have shape batchsize, n_channels, 1, 1
            # input = input * torch.exp(logs+logs_offset)
        else:
            input = input * torch.exp(-logs)
        if logdet is not None:
            """
            logs is log_std of `mean of channels`,
            so we need to multiply by the number of pixels
            """
            dlogdet = thops.sum(logs) * thops.pixels(input)
            if reverse:
                dlogdet *= -1
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
        if not self.inited:
            self.initialize_parameters(input)
        self._check_input_dim(input)

        if offset_mask is not None:
            logs_offset *= offset_mask
            bias_offset *= offset_mask
        # no need to permute dims as in the old version
        if not reverse:
            # center and scale

            # self.input = input
            input = self._center(input, reverse, bias_offset)
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
        else:
            # scale and center
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
            input = self._center(input, reverse, bias_offset)
        return input, logdet


class ActNorm2d(_ActNorm):
    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        assert len(input.size()) == 4
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size()))


class MaskedActNorm2d(ActNorm2d):
    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def forward(self, input, mask, logdet=None, reverse=False):

        assert mask.dtype == torch.bool
        output, logdet_out = super().forward(input, logdet, reverse)

        input[mask] = output[mask]
        logdet[mask] = logdet_out[mask]

        return input, logdet
@@ -1,116 +0,0 @@
import torch
from torch import nn as nn

from models.archs.srflow import thops
from models.archs.srflow.flow import Conv2d, Conv2dZeros


class CondAffineSeparatedAndCond(nn.Module):
    def __init__(self, in_channels, hidden_channels=64, affine_eps=.00001):
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        self.in_channels_rrdb = 320
        self.kernel_hidden = 1
        self.affine_eps = 0.0001
        self.n_hidden_layers = 1
        self.hidden_channels = hidden_channels
        self.affine_eps = affine_eps

        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn

        if self.channels_for_nn is None:
            self.channels_for_nn = self.in_channels // 2

        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)

        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)

    def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
        if not reverse:
            z = input
            assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z + shiftFt
            z = z * scaleFt
            logdet = logdet + self.get_logdet(scaleFt)

            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 + shift
            z2 = z2 * scale

            logdet = logdet + self.get_logdet(scale)
            z = thops.cat_feature(z1, z2)
            output = z
        else:
            z = input

            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 / scale
            z2 = z2 - shift
            z = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(scale)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z / scaleFt
            z = z - shiftFt
            logdet = logdet - self.get_logdet(scaleFt)

            output = z
        return output, logdet

    def asserts(self, scale, shift, z1, z2):
        assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
        assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
        assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
        assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])

    def get_logdet(self, scale):
        return thops.sum(torch.log(scale), dim=[1, 2, 3])

    def feature_extract(self, z, f):
        h = f(z)
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def feature_extract_aff(self, z1, ft, f):
        z = torch.cat([z1, ft], dim=1)
        h = f(z)
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def split(self, z):
        z1 = z[:, :self.channels_for_nn]
        z2 = z[:, self.channels_for_nn:]
        assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
        return z1, z2

    def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
        layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]

        for _ in range(n_hidden_layers):
            layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
            layers.append(nn.ReLU(inplace=False))
        layers.append(Conv2dZeros(hidden_channels, out_channels))

        return nn.Sequential(*layers)
@@ -1,117 +0,0 @@
import torch
from torch import nn as nn

from models.archs.srflow import flow, thops, FlowAffineCouplingsAblation, FlowActNorms, Permutations


def getConditional(rrdbResults, position):
    img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position]
    return img_ft


class FlowStep(nn.Module):
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
    }

    def __init__(self, in_channels, hidden_channels,
                 actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
                 LU_decomposed=False, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
                 position=None):
        # check configuration
        assert flow_permutation in FlowStep.FlowPermutation, \
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.image_injector = image_injector

        self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'
        self.position = normOpt['position'] if normOpt else None

        self.in_shape = in_shape
        self.position = position
        self.acOpt = acOpt

        # 1. actnorm
        self.actnorm = FlowActNorms.ActNorm2d(in_channels, actnorm_scale)

        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = Permutations.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)

        # 3. coupling
        if flow_coupling == "CondAffineSeparatedAndCond":
            self.affine = FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels)
        elif flow_coupling == "noCoupling":
            pass
        else:
            raise RuntimeError("coupling not found:", flow_coupling)

    def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
        if not reverse:
            return self.normal_flow(input, logdet, rrdbResults)
        else:
            return self.reverse_flow(input, logdet, rrdbResults)

    def normal_flow(self, z, logdet, rrdbResults=None):
        if self.flow_coupling == "bentIdentityPreAct":
            z, logdet = self.bentIdentPar(z, logdet, reverse=False)

        # 1. actnorm
        if self.norm_type == "ConditionalActNormImageInjector":
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
        elif self.norm_type == "noNorm":
            pass
        else:
            z, logdet = self.actnorm(z, logdet=logdet, reverse=False)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)

        need_features = self.affine_need_features()

        # 3. coupling
        if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
        return z, logdet

    def reverse_flow(self, z, logdet, rrdbResults=None):

        need_features = self.affine_need_features()

        # 1. coupling
        if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)

        # 3. actnorm
        z, logdet = self.actnorm(z, logdet=logdet, reverse=True)

        return z, logdet

    def affine_need_features(self):
        need_features = False
        try:
            need_features = self.affine.need_features
        except:
            pass
        return need_features
@@ -1,267 +0,0 @@
import numpy as np
import torch
from torch import nn as nn

import models.archs.srflow.Split
from models.archs.srflow import flow, thops, Split
from models.archs.srflow.Split import Split2d
from models.archs.srflow.glow_arch import f_conv2d_bias
from models.archs.srflow.FlowStep import FlowStep
from utils.util import opt_get
import torchvision


class FlowUpsamplerNet(nn.Module):
    def __init__(self, image_shape, hidden_channels, scale,
                 rrdb_blocks,
                 actnorm_scale=1.0,
                 flow_permutation='invconv',
                 flow_coupling="affine",
                 LU_decomposed=False, K=16, L=3,
                 norm_opt=None,
                 n_bypass_channels=None):

        super().__init__()

        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.L = L
        self.K = K
        self.scale = scale
        if isinstance(self.K, int):
            self.K = [K for K in [K, ] * (self.L + 1)]

        H, W, self.C = image_shape
        self.image_shape = image_shape
        self.check_image_shape()

        if scale == 16:
            self.levelToName = {
                0: 'fea_up16',
                1: 'fea_up8',
                2: 'fea_up4',
                3: 'fea_up2',
                4: 'fea_up1',
            }

        if scale == 8:
            self.levelToName = {
                0: 'fea_up8',
                1: 'fea_up4',
                2: 'fea_up2',
                3: 'fea_up1',
                4: 'fea_up0'
            }

        elif scale == 4:
            self.levelToName = {
                0: 'fea_up4',
                1: 'fea_up2',
                2: 'fea_up1',
                3: 'fea_up0',
                4: 'fea_up-1'
            }

        affineInCh = self.get_affineInCh(rrdb_blocks)

        conditional_channels = {}
        n_rrdb = self.get_n_rrdb_channels(rrdb_blocks)
        conditional_channels[0] = n_rrdb
        for level in range(1, self.L + 1):
            # Level 1 gets conditionals from 2, 3, 4 => L - level
            # Level 2 gets conditionals from 3, 4
            # Level 3 gets conditionals from 4
            # Level 4 gets conditionals from None
            n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
            conditional_channels[level] = n_rrdb + n_bypass

        # Upsampler
        for level in range(1, self.L + 1):
            # 1. Squeeze
            H, W = self.arch_squeeze(H, W)

            # 2. K FlowStep
            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels)
            self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                               flow_permutation,
                               hidden_channels, norm_opt,
                               n_conditional_channels=conditional_channels[level])
            # Split
            self.arch_split(H, W, level, self.L)

        self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
        self.H = H
        self.W = W

    def get_n_rrdb_channels(self, blocks):
        n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
        return n_rrdb

    def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
                      hidden_channels, normOpt, n_conditional_channels=None, condAff=None):
        if condAff is not None:
            condAff['in_channels_rrdb'] = n_conditional_channels

        for k in range(K):
            position_name = self.get_position_name(H, self.scale)
            if normOpt: normOpt['position'] = position_name

            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         acOpt=condAff,
                         position=position_name,
                         LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5, split_conditional=False, cond_channels=None, split_type='Split2d'):
        correction = 0 if correct_splits else 1
        if split_flow and L < levels - correction:
            logs_eps = logs_eps
            consume_ratio = consume_ratio
            position_name = self.get_position_name(H, self.scale)
            position = position_name if split_conditional else None
            cond_channels = 0 if cond_channels is None else cond_channels

            if split_type == 'Split2d':
                split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                      cond_channels=cond_channels, consume_ratio=consume_ratio)
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            self.C = split.num_channels_pass

    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2):
        for _ in range(additionalFlowNoAffine):
            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation='invconv',
                         flow_coupling='noCoupling',
                         LU_decomposed=LU_decomposed))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def arch_squeeze(self, H, W):
        self.C, H, W = self.C * 4, H // 2, W // 2
        self.layers.append(flow.SqueezeLayer(factor=2))
        self.output_shapes.append([-1, self.C, H, W])
        return H, W

    def get_affineInCh(self, rrdb_blocks):
        affineInCh = (len(rrdb_blocks) + 1) * 64
        return affineInCh

    def check_image_shape(self):
        assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3); "
                                            "self.C == 1 or self.C == 3")

    def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
                y_onehot=None):

        if reverse:
            epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses

            sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
            return sr, logdet
        else:
            assert gt is not None
            assert rrdbResults is not None
            z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)

            return z, logdet

    def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
        fl_fea = gt
        reverse = False
        level_conditionals = {}
        bypasses = {}

        L = self.L

        for level in range(1, L + 1):
            bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)

        for layer, shape in zip(self.layers, self.output_shapes):
            size = shape[2]
            level = int(np.log(self.image_shape[0] / size) / np.log(2))

            if level > 0 and level not in level_conditionals.keys():
                level_conditionals[level] = rrdbResults[self.levelToName[level]]

            level_conditionals[level] = rrdbResults[self.levelToName[level]]

            if isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
            elif isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
                                                      y_onehot=y_onehot)
            else:
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)

        z = fl_fea

        if not isinstance(epses, list):
            return z, logdet

        epses.append(z)
        return epses, logdet

    def forward_preFlow(self, fl_fea, logdet, reverse):
        if hasattr(self, 'preFlow'):
            for l in self.preFlow:
                fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
        return fl_fea, logdet

    def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)

        if isinstance(epses, list):
            epses.append(eps)
        return fl_fea, logdet

    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
        z = epses.pop() if isinstance(epses, list) else z

        fl_fea = z
        # debug.imwrite("fl_fea", fl_fea)
        bypasses = {}
        level_conditionals = {}
        for level in range(self.L + 1):
            level_conditionals[level] = rrdbResults[self.levelToName[level]]

        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            size = shape[2]
            level = int(np.log(self.H / size) / np.log(2))

            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)

        sr = fl_fea

        assert sr.shape[1] == 3
        return sr, logdet

    def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                               eps=epses.pop() if isinstance(epses, list) else None,
                               eps_std=eps_std, ft=ft, y_onehot=y_onehot)
        return fl_fea, logdet

    def get_position_name(self, H, scale):
        downscale_factor = self.image_shape[0] // H
        position_name = 'fea_up{}'.format(scale / downscale_factor)
        return position_name
@@ -1,42 +0,0 @@
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from models.archs.srflow import thops


class InvertibleConv1x1(nn.Module):
    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        self.w_shape = w_shape
        self.LU = LU_decomposed

    def get_weight(self, input, reverse):
        w_shape = self.w_shape
        pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if not reverse:
            weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
        else:
            weight = torch.inverse(self.weight.double()).float() \
                .view(w_shape[0], w_shape[1], 1, 1)
        return weight, dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """
        log-det = log|abs(|W|)| * pixels
        """
        weight, dlogdet = self.get_weight(input, reverse)
        if not reverse:
            z = F.conv2d(input, weight)
            if logdet is not None:
                logdet = logdet + dlogdet
            return z, logdet
        else:
            z = F.conv2d(input, weight)
            if logdet is not None:
                logdet = logdet - dlogdet
            return z, logdet
@@ -1,133 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.srflow.module_util as mutil
from utils.util import opt_get, checkpoint


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, block_outputs=[], fea_up0=True,
                 fea_up1=False):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.scale = scale
        self.block_outputs = block_outputs
        self.fea_up0 = fea_up0
        self.fea_up1 = fea_up1

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 8:
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 16:
            self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 32:
            self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, get_steps=False):
        fea = self.conv_first(x)

        block_idxs = self.block_outputs or []
        block_results = {}

        for idx, m in enumerate(self.RRDB_trunk.children()):
            fea = checkpoint(m, fea)
            for b in block_idxs:
                if b == idx:
                    block_results["block_{}".format(idx)] = fea

        trunk = self.trunk_conv(fea)

        last_lr_fea = fea + trunk

        fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up2)

        fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up4)

        fea_up8 = None
        fea_up16 = None
        fea_up32 = None

        if self.scale >= 8:
            fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up8)
        if self.scale >= 16:
            fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up16)
        if self.scale >= 32:
            fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up32)

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        results = {'last_lr_fea': last_lr_fea,
                   'fea_up1': last_lr_fea,
                   'fea_up2': fea_up2,
                   'fea_up4': fea_up4,
                   'fea_up8': fea_up8,
                   'fea_up16': fea_up16,
                   'fea_up32': fea_up32,
                   'out': out}

        if self.fea_up0:
            results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        if self.fea_up1:
            results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)

        if get_steps:
            for k, v in block_results.items():
                results[k] = v
            return results
        else:
            return out
@@ -1,68 +0,0 @@
import torch
from torch import nn as nn

from models.archs.srflow import thops
from models.archs.srflow.FlowStep import FlowStep
from models.archs.srflow.flow import Conv2dZeros, GaussianDiag


class Split2d(nn.Module):
    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5):
        super().__init__()

        self.num_channels_consume = int(round(num_channels * consume_ratio))
        self.num_channels_pass = num_channels - self.num_channels_consume

        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=self.num_channels_consume * 2)
        self.logs_eps = logs_eps
        self.position = position

    def split2d_prior(self, z, ft):
        if ft is not None:
            z = torch.cat([z, ft], dim=1)
        h = self.conv(z)
        return thops.split_feature(h, "cross")

    def exp_eps(self, logs):
        return torch.exp(logs) + self.logs_eps

    def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
        if not reverse:
            # self.input = input
            z1, z2 = self.split_ratio(input)
            mean, logs = self.split2d_prior(z1, ft)

            eps = (z2 - mean) / self.exp_eps(logs)

            logdet = logdet + self.get_logdet(logs, mean, z2)

            # print(logs.shape, mean.shape, z2.shape)
            # self.eps = eps
            # print('split, enc eps:', eps)
            return z1, logdet, eps
        else:
            z1 = input
            mean, logs = self.split2d_prior(z1, ft)

            if eps is None:
                # print("WARNING: eps is None, generating eps untested functionality!")
                eps = GaussianDiag.sample_eps(mean.shape, eps_std)

            eps = eps.to(mean.device)
            z2 = mean + self.exp_eps(logs) * eps

            z = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(logs, mean, z2)

            return z, logdet
            # return z, logdet, eps

    def get_logdet(self, logs, mean, z2):
        logdet_diff = GaussianDiag.logp(mean, logs, z2)
        # print("Split2D: logdet diff", logdet_diff.item())
        return logdet_diff

    def split_ratio(self, input):
        z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
        return z1, z2
@@ -1,150 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.archs.srflow.FlowActNorms import ActNorm2d
from . import thops


class Conv2d(nn.Conv2d):
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel]
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        # make padding
        if isinstance(padding, str):
            if isinstance(kernel_size, int):
                kernel_size = [kernel_size, kernel_size]
            if isinstance(stride, int):
                stride = [stride, stride]
            padding = padding.lower()
            try:
                padding = Conv2d.pad_dict[padding](kernel_size, stride)
            except KeyError:
                raise ValueError("{} is not supported".format(padding))
        return padding

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", do_actnorm=True, weight_std=0.05):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, bias=(not do_actnorm))
        # init weight with std
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if not do_actnorm:
            self.bias.data.zero_()
        else:
            self.actnorm = ActNorm2d(out_channels)
        self.do_actnorm = do_actnorm

    def forward(self, input):
        x = super().forward(input)
        if self.do_actnorm:
            x, _ = self.actnorm(x)
        return x


class Conv2dZeros(nn.Conv2d):
    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", logscale_factor=3):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding)
        # logscale_factor
        self.logscale_factor = logscale_factor
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
        # init
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        output = super().forward(input)
        return output * torch.exp(self.logs * self.logscale_factor)


class GaussianDiag:
    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(mean, logs, x):
        """
        lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
        k = 1 (Independent)
        Var = logs ** 2
        """
        if mean is None and logs is None:
            return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
        else:
            return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)

    @staticmethod
    def logp(mean, logs, x):
        likelihood = GaussianDiag.likelihood(mean, logs, x)
        return thops.sum(likelihood, dim=[1, 2, 3])

    @staticmethod
    def sample(mean, logs, eps_std=None):
        eps_std = eps_std or 1
        eps = torch.normal(mean=torch.zeros_like(mean),
                           std=torch.ones_like(logs) * eps_std)
        return mean + torch.exp(logs) * eps

    @staticmethod
    def sample_eps(shape, eps_std, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        eps = torch.normal(mean=torch.zeros(shape),
                           std=torch.ones(shape) * eps_std)
        return eps


def squeeze2d(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x


def unsqueeze2d(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % (factor2) == 0, "{}".format(C)
    x = input.view(B, C // factor2, factor, factor, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // (factor2), H * factor, W * factor)
    return x


class SqueezeLayer(nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        if not reverse:
            output = squeeze2d(input, self.factor)  # Squeeze in forward
            return output, logdet
        else:
            output = unsqueeze2d(input, self.factor)
            return output, logdet
@@ -1,12 +0,0 @@
import torch.nn as nn


def f_conv2d_bias(in_channels, out_channels):
    def padding_same(kernel, stride):
        return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]

    padding = padding_same([3, 3], [1, 1])
    assert padding == [1, 1], padding
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1,
                  bias=True))
@@ -1,79 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
    """Warp an image or feature map with optical flow

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), normal value
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3]
    B, C, H, W = x.size()
    # mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False
    grid = grid.type_as(x)
    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
    return output
@@ -1,135 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.archs.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.archs.srflow.thops as thops
import models.archs.srflow.flow as flow
from models.archs.srflow.RRDBNet_arch import RRDBNet


class SRFlowNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, quant, flow_block_maps, noise_quant,
                 hidden_channels=64, gc=32, scale=4, K=16, L=3, train_rrdb_at_step=0,
                 hr_img_shape=(128, 128, 3), coupling='CondAffineSeparatedAndCond'):
        super(SRFlowNet, self).__init__()

        self.scale = scale
        self.noise_quant = noise_quant
        self.quant = quant
        self.flow_block_maps = flow_block_maps
        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, flow_block_maps)
        self.train_rrdb_step = train_rrdb_at_step
        self.RRDB_training = True

        self.flowUpsamplerNet = FlowUpsamplerNet(image_shape=hr_img_shape,
                                                 hidden_channels=hidden_channels,
                                                 scale=scale, rrdb_blocks=flow_block_maps,
                                                 K=K, L=L, flow_coupling=coupling)
        self.i = 0

    def forward(self, gt=None, lr=None, reverse=False, z=None, eps_std=None, epses=None, reverse_with_grad=False,
                lr_enc=None,
                add_gt_noise=False, step=None, y_label=None):
        if not reverse:
            return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                    y_onehot=y_label)
        else:
            assert lr.shape[1] == 3
            if reverse_with_grad:
                return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                         add_gt_noise=add_gt_noise)
            else:
                with torch.no_grad():
                    return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                             add_gt_noise=add_gt_noise)

    def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        logdet = torch.zeros_like(gt[:, 0, 0, 0])
        pixels = thops.pixels(gt)

        z = gt

        if add_gt_noise:
            # Setup
            if self.noise_quant:
                z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
            logdet = logdet + float(-np.log(self.quant) * pixels)

        # Encode
        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
                                              y_onehot=y_onehot)

        objective = logdet.clone()

        if isinstance(epses, (list, tuple)):
            z = epses[-1]
        else:
            z = epses

        objective = objective + flow.GaussianDiag.logp(None, None, z)

        nll = (-objective) / float(np.log(2.) * pixels)

        if isinstance(epses, list):
            return epses, nll, logdet
        return z, nll, logdet

    def rrdbPreprocessing(self, lr):
        rrdbResults = self.RRDB(lr, get_steps=True)
        block_idxs = self.flow_block_maps
        if len(block_idxs) > 0:
            concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)

            keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
            if 'fea_up0' in rrdbResults.keys():
                keys.append('fea_up0')
            if 'fea_up-1' in rrdbResults.keys():
                keys.append('fea_up-1')
            if self.scale >= 8:
                keys.append('fea_up8')
            if self.scale == 16:
                keys.append('fea_up16')
            for k in keys:
                h = rrdbResults[k].shape[2]
                w = rrdbResults[k].shape[3]
                rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
        return rrdbResults

    def get_score(self, disc_loss_sigma, z):
        score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
                     z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
        return -score_real

    def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
        logdet = torch.zeros_like(lr[:, 0, 0, 0])
        pixels = thops.pixels(lr) * self.scale ** 2

        if add_gt_noise:
            logdet = logdet - float(-np.log(self.quant) * pixels)

        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
                                          logdet=logdet)

        return x, logdet

    def set_rrdb_training(self, trainable):
        if self.RRDB_training != trainable:
            for p in self.RRDB.parameters():
                if not trainable:
                    p.DO_NOT_TRAIN = True
                elif hasattr(p, "DO_NOT_TRAIN"):
                    del p.DO_NOT_TRAIN
            self.RRDB_training = trainable

    def update_for_step(self, step, experiments_path='.'):
        self.set_rrdb_training(step > self.train_rrdb_step)
@@ -1,52 +0,0 @@
import torch


def sum(tensor, dim=None, keepdim=False):
    if dim is None:
        # sum up all dim
        return torch.sum(tensor)
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.sum(dim=d, keepdim=True)
        if not keepdim:
            for i, d in enumerate(dim):
                tensor.squeeze_(d - i)
        return tensor


def mean(tensor, dim=None, keepdim=False):
    if dim is None:
        # mean all dim
        return torch.mean(tensor)
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.mean(dim=d, keepdim=True)
        if not keepdim:
            for i, d in enumerate(dim):
                tensor.squeeze_(d - i)
        return tensor


def split_feature(tensor, type="split"):
    """
    type = ["split", "cross"]
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    elif type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]


def cat_feature(tensor_a, tensor_b):
    return torch.cat((tensor_a, tensor_b), dim=1)


def pixels(tensor):
    return int(tensor.size(2) * tensor.size(3))