import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.srflow.FlowActNorms import ActNorm2d
from . import thops


class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` with string padding, Gaussian weight init and optional ActNorm.

    When ``do_actnorm`` is True the convolution is built without a bias and its
    output is passed through :class:`ActNorm2d`; otherwise the bias is kept and
    zero-initialized. Weights are drawn from N(0, weight_std ** 2).
    """

    # Maps a padding mode name to a function (kernel_size, stride) -> per-dim padding.
    # NOTE(review): "same" uses the dilated-"same" formula ((k - 1) * d + 1) // 2
    # with the *stride* in the dilation slot. With stride 1 and an odd kernel it
    # yields (k - 1) / 2, which preserves the spatial size; for stride > 1 it does
    # not match TensorFlow-style "same" padding -- confirm this is intended.
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel],
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        """Resolve ``padding`` into a value accepted by ``nn.Conv2d``.

        Args:
            padding: a mode name from ``pad_dict`` ("same" or "valid",
                case-insensitive), or an explicit padding value which is
                returned unchanged.
            kernel_size: int or per-dimension sequence of kernel sizes.
            stride: int or per-dimension sequence of strides.

        Returns:
            The explicit padding; a list with one entry per kernel dimension
            when ``padding`` was a mode name.

        Raises:
            ValueError: if ``padding`` is a string that is not in ``pad_dict``.
        """
        if not isinstance(padding, str):
            return padding
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]
        if isinstance(stride, int):
            stride = [stride, stride]
        padding = padding.lower()
        try:
            make_padding = Conv2d.pad_dict[padding]
        except KeyError:
            # The lookup failure itself carries no extra information.
            raise ValueError("{} is not supported".format(padding)) from None
        return make_padding(kernel_size, stride)

    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1),
                 padding="same", do_actnorm=True, weight_std=0.05):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        # The conv bias is dropped when ActNorm follows the convolution
        # (presumably because ActNorm2d supplies its own per-channel shift).
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, bias=(not do_actnorm))
        # Initialize weights with a small-std Gaussian.
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if not do_actnorm:
            self.bias.data.zero_()
        else:
            self.actnorm = ActNorm2d(out_channels)
        self.do_actnorm = do_actnorm

    def forward(self, input):
        """Convolve ``input`` and, if enabled, apply ActNorm to the result."""
        x = super().forward(input)
        if self.do_actnorm:
            # ActNorm2d returns a pair; the second element (presumably a
            # log-determinant -- TODO confirm) is not needed here.
            x, _ = self.actnorm(x)
        return x


class Conv2dZeros(nn.Conv2d):
    """Zero-initialized ``nn.Conv2d`` with a learned per-channel output scale.

    Weights and bias start at zero, so the layer initially outputs zeros. The
    convolution output is multiplied by ``exp(logs * logscale_factor)``, where
    ``logs`` is a learnable parameter of shape ``(out_channels, 1, 1)`` that is
    initialized to zero (initial scale 1).
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1),
                 padding="same", logscale_factor=3):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding)
        # Multiplier applied to ``logs`` before exponentiation.
        self.logscale_factor = logscale_factor
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
        # Zero init: the layer starts out as a constant-zero map.
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        """Convolve ``input`` and scale each output channel by ``exp(logs * logscale_factor)``."""
        output = super().forward(input)
        # ``logs`` has shape (out_channels, 1, 1) and broadcasts over spatial dims.
        return output * torch.exp(self.logs * self.logscale_factor)


class GaussianDiag:
    """Helpers for a diagonal Gaussian parameterized by ``mean`` and log-std ``logs``."""

    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(mean, logs, x):
        """Element-wise log-density of ``x`` under N(mean, exp(logs) ** 2).

        lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + k*ln(2*PI) }
        with k = 1 (independent dimensions) and Var = exp(logs) ** 2 = exp(2 * logs),
        i.e. ``logs`` is the log of the standard deviation.

        ``mean=None`` is treated as a zero mean and ``logs=None`` as unit
        variance; passing both as None gives the standard normal density.
        """
        centered = x if mean is None else x - mean
        if logs is None:
            return -0.5 * (centered ** 2 + GaussianDiag.Log2PI)
        return -0.5 * (logs * 2. + centered ** 2 / torch.exp(logs * 2.) + GaussianDiag.Log2PI)

    @staticmethod
    def logp(mean, logs, x):
        """Return the element-wise log-density reduced with ``thops.sum`` over dims 1, 2 and 3.

        Presumably ``x`` is (B, C, H, W) and the result is one log-density per
        batch element -- TODO confirm against ``thops.sum``.
        """
        likelihood = GaussianDiag.likelihood(mean, logs, x)
        return thops.sum(likelihood, dim=[1, 2, 3])

    @staticmethod
    def sample(mean, logs, eps_std=None):
        """Draw ``mean + exp(logs) * eps`` with ``eps ~ N(0, eps_std ** 2)``.

        ``eps_std`` defaults to 1 only when it is None. An explicit
        ``eps_std=0`` is honoured (zero-temperature sampling, yielding ``mean``)
        rather than being silently replaced by 1.
        """
        if eps_std is None:
            eps_std = 1
        eps = torch.normal(mean=torch.zeros_like(mean),
                           std=torch.ones_like(logs) * eps_std)
        return mean + torch.exp(logs) * eps

    @staticmethod
    def sample_eps(shape, eps_std, seed=None):
        """Draw a tensor of the given ``shape`` from N(0, eps_std ** 2).

        The tensor is created with torch's default dtype/device. If ``seed`` is
        given, the *global* torch RNG is re-seeded via ``torch.manual_seed``
        first, which also affects every later random draw in the process.
        """
        if seed is not None:
            torch.manual_seed(seed)
        eps = torch.normal(mean=torch.zeros(shape), std=torch.ones(shape) * eps_std)
        return eps


def squeeze2d(input, factor=2):
    """Space-to-depth: (B, C, H, W) -> (B, C * factor**2, H // factor, W // factor).

    Each non-overlapping ``factor x factor`` spatial block is moved into the
    channel dimension: output channel ``c * factor**2 + i * factor + j`` holds the
    pixel at row offset ``i`` and column offset ``j`` of the block from input
    channel ``c``. :func:`unsqueeze2d` is the exact inverse. ``factor == 1``
    returns ``input`` itself.
    """
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return input
    size = input.size()
    B, C, H, W = size[0], size[1], size[2], size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    # -> (B, C, i, j, H // factor, W // factor): block offsets next to the channel.
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x


def unsqueeze2d(input, factor=2):
    """Depth-to-space, the inverse of :func:`squeeze2d`.

    (B, C, H, W) -> (B, C // factor**2, H * factor, W * factor). ``C`` must be
    divisible by ``factor ** 2``. ``factor == 1`` returns ``input`` itself.
    """
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return input
    factor2 = factor ** 2
    size = input.size()
    B, C, H, W = size[0], size[1], size[2], size[3]
    assert C % factor2 == 0, "{}".format(C)
    x = input.view(B, C // factor2, factor, factor, H, W)
    # -> (B, C // factor**2, H, i, W, j) so each row/column interleaves the block offsets.
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // factor2, H * factor, W * factor)
    return x


class SqueezeLayer(nn.Module):
    """Flow layer wrapping :func:`squeeze2d` (forward) / :func:`unsqueeze2d` (reverse).

    The reshuffle only rearranges elements, so ``logdet`` is passed through
    unchanged in both directions.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        """Return ``(output, logdet)``; squeezes when ``reverse`` is False, unsqueezes otherwise."""
        if reverse:
            return unsqueeze2d(input, self.factor), logdet
        return squeeze2d(input, self.factor), logdet