DL-Art-School/codes/models/srflow_orig/flow.py

151 lines
4.8 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
2020-11-20 06:47:24 +00:00
from models.archs.srflow_orig.FlowActNorms import ActNorm2d
from . import thops
class Conv2d(nn.Conv2d):
pad_dict = {
"same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
"valid": lambda kernel, stride: [0 for _ in kernel]
}
@staticmethod
def get_padding(padding, kernel_size, stride):
# make paddding
if isinstance(padding, str):
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
if isinstance(stride, int):
stride = [stride, stride]
padding = padding.lower()
try:
padding = Conv2d.pad_dict[padding](kernel_size, stride)
except KeyError:
raise ValueError("{} is not supported".format(padding))
return padding
def __init__(self, in_channels, out_channels,
kernel_size=[3, 3], stride=[1, 1],
padding="same", do_actnorm=True, weight_std=0.05):
padding = Conv2d.get_padding(padding, kernel_size, stride)
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, bias=(not do_actnorm))
# init weight with std
self.weight.data.normal_(mean=0.0, std=weight_std)
if not do_actnorm:
self.bias.data.zero_()
else:
self.actnorm = ActNorm2d(out_channels)
self.do_actnorm = do_actnorm
def forward(self, input):
x = super().forward(input)
if self.do_actnorm:
x, _ = self.actnorm(x)
return x
class Conv2dZeros(nn.Conv2d):
def __init__(self, in_channels, out_channels,
kernel_size=[3, 3], stride=[1, 1],
padding="same", logscale_factor=3):
padding = Conv2d.get_padding(padding, kernel_size, stride)
super().__init__(in_channels, out_channels, kernel_size, stride, padding)
# logscale_factor
self.logscale_factor = logscale_factor
self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
# init
self.weight.data.zero_()
self.bias.data.zero_()
def forward(self, input):
output = super().forward(input)
return output * torch.exp(self.logs * self.logscale_factor)
class GaussianDiag:
Log2PI = float(np.log(2 * np.pi))
@staticmethod
def likelihood(mean, logs, x):
"""
lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
k = 1 (Independent)
Var = logs ** 2
"""
if mean is None and logs is None:
return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
else:
return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)
@staticmethod
def logp(mean, logs, x):
likelihood = GaussianDiag.likelihood(mean, logs, x)
return thops.sum(likelihood, dim=[1, 2, 3])
@staticmethod
def sample(mean, logs, eps_std=None):
eps_std = eps_std or 1
eps = torch.normal(mean=torch.zeros_like(mean),
std=torch.ones_like(logs) * eps_std)
return mean + torch.exp(logs) * eps
@staticmethod
def sample_eps(shape, eps_std, seed=None):
if seed is not None:
torch.manual_seed(seed)
eps = torch.normal(mean=torch.zeros(shape),
std=torch.ones(shape) * eps_std)
return eps
def squeeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
x = input.view(B, C, H // factor, factor, W // factor, factor)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(B, C * factor * factor, H // factor, W // factor)
return x
def unsqueeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
factor2 = factor ** 2
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert C % (factor2) == 0, "{}".format(C)
x = input.view(B, C // factor2, factor, factor, H, W)
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(B, C // (factor2), H * factor, W * factor)
return x
class SqueezeLayer(nn.Module):
def __init__(self, factor):
super().__init__()
self.factor = factor
def forward(self, input, logdet=None, reverse=False):
if not reverse:
output = squeeze2d(input, self.factor) # Squeeze in forward
return output, logdet
else:
output = unsqueeze2d(input, self.factor)
return output, logdet