DL-Art-School/dlas/models/image_generation/srflow/FlowActNorms.py

127 lines
4.1 KiB
Python

import torch
from torch import nn as nn
from dlas.models.image_generation.srflow import thops
class _ActNorm(nn.Module):
    """Activation normalization: a learnable per-channel affine transform.

    Two parameters of shape ``[1, num_features, 1, 1]`` are registered:
    ``bias`` (added to the input) and ``logs`` (log of the multiplicative
    scale).  :meth:`initialize_parameters` can set both from a minibatch so
    that, for that batch, each output channel has zero mean and a standard
    deviation of ``scale``.  Afterwards ``bias`` and ``logs`` are trained as
    ordinary parameters.

    NOTE(review): ``forward`` does not call ``initialize_parameters`` and
    nothing in this class reads ``inited``; the data-dependent
    initialization has to be triggered by the caller -- confirm intended.
    """

    def __init__(self, num_features, scale=1.):
        """Create zero-initialized ``bias`` and ``logs`` parameters.

        Args:
            num_features: Number of channels (size of dimension 1).
            scale: Target per-channel standard deviation used by the
                data-dependent initialization.
        """
        super().__init__()
        # Parameters broadcast over batch and both spatial dimensions.
        size = [1, num_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        self.scale = float(scale)
        self.inited = False

    def _check_input_dim(self, input):
        """Input validation hook; subclasses override it.

        The base implementation performs no check and returns
        ``NotImplemented``.
        """
        return NotImplemented

    def initialize_parameters(self, input):
        """Set ``bias`` and ``logs`` from the statistics of ``input``.

        Does nothing in eval mode.  If ``bias`` already holds any non-zero
        value the module is treated as initialized (e.g. weights were
        loaded) and only the ``inited`` flag is set.

        Args:
            input: Minibatch whose mean / variance over dims 0, 2 and 3
                are used for the initialization.
        """
        self._check_input_dim(input)
        if not self.training:
            return
        if (self.bias != 0).any():
            # Parameters were already set; do not overwrite them.
            self.inited = True
            return
        assert input.device == self.bias.device, (
            input.device, self.bias.device)
        with torch.no_grad():
            # bias is the negated per-channel mean, so `input + bias` is
            # centered.
            bias = thops.mean(input.clone(), dim=[0, 2, 3],
                              keepdim=True) * -1.0
            # `input + bias` allocates a new tensor, so no clone is needed.
            variance = thops.mean((input + bias) ** 2,
                                  dim=[0, 2, 3], keepdim=True)
            # exp(logs) == scale / std; the epsilon avoids division by zero.
            logs = torch.log(self.scale / (torch.sqrt(variance) + 1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.inited = True

    def _center(self, input, reverse=False, offset=None):
        """Add (forward) or subtract (reverse) ``bias`` plus ``offset``."""
        bias = self.bias
        if offset is not None:
            bias = bias + offset
        if not reverse:
            return input + bias
        return input - bias

    def _scale(self, input, logdet=None, reverse=False, offset=None):
        """Multiply by ``exp(logs)`` (forward) or ``exp(-logs)`` (reverse).

        Args:
            input: Tensor to scale.
            logdet: Running log-determinant, or ``None`` to skip tracking.
            reverse: Apply the inverse scaling when true.
            offset: Optional term added to ``logs`` before use.

        Returns:
            ``(scaled_input, logdet)``; ``logdet`` stays ``None`` when it
            was passed as ``None``.
        """
        logs = self.logs
        if offset is not None:
            logs = logs + offset
        if not reverse:
            input = input * torch.exp(logs)
        else:
            input = input * torch.exp(-logs)
        if logdet is not None:
            # `logs` holds one log-scale per channel that is applied at
            # every pixel, so the summed log-scale is multiplied by the
            # pixel count reported by `thops.pixels`.
            dlogdet = thops.sum(logs) * thops.pixels(input)
            if reverse:
                dlogdet = -dlogdet
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False, offset_mask=None,
                logs_offset=None, bias_offset=None):
        """Apply the affine transform or its inverse.

        Forward order is center then scale; reverse order is un-scale then
        un-center.

        Args:
            input: Tensor to transform.
            logdet: Running log-determinant, or ``None`` to skip tracking.
            reverse: Run the inverse transform when true.
            offset_mask: Optional multiplier applied to both offsets.
            logs_offset: Optional term added to ``logs``.
            bias_offset: Optional term added to ``bias``.

        Returns:
            ``(output, logdet)``.
        """
        self._check_input_dim(input)
        if offset_mask is not None:
            # Multiply out of place so the caller's tensors are not
            # modified, and tolerate a mask given without one (or both)
            # of the offsets.
            if logs_offset is not None:
                logs_offset = logs_offset * offset_mask
            if bias_offset is not None:
                bias_offset = bias_offset * offset_mask
        if not reverse:
            # center, then scale
            input = self._center(input, reverse, bias_offset)
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
        else:
            # undo scale, then undo center
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
            input = self._center(input, reverse, bias_offset)
        return input, logdet
class ActNorm2d(_ActNorm):
    """Activation normalization that validates 4-D ``BCHW`` inputs."""

    def __init__(self, num_features, scale=1.):
        """Forward ``num_features`` and ``scale`` to the base class."""
        super().__init__(num_features, scale=scale)

    def _check_input_dim(self, input):
        """Assert ``input`` has 4 dims and ``num_features`` channels."""
        shape = input.size()
        assert len(shape) == 4
        channels = shape[1]
        assert channels == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                self.num_features, shape))
class MaskedActNorm2d(ActNorm2d):
    """ActNorm2d that only applies its result where ``mask`` is true."""

    def __init__(self, num_features, scale=1.):
        """Forward ``num_features`` and ``scale`` to the base class."""
        super().__init__(num_features, scale)

    def forward(self, input, mask, logdet=None, reverse=False):
        """Run act-norm and keep the result only at masked positions.

        The transform is computed for the whole ``input``; entries selected
        by ``mask`` are then written back into ``input`` (and ``logdet``)
        in place, and those same objects are returned.

        Args:
            input: Tensor to transform; modified in place.
            mask: Boolean tensor used to index both ``input`` and
                ``logdet`` -- presumably one flag per batch element;
                confirm against callers.
            logdet: Running log-determinant, modified in place, or
                ``None`` to skip log-det tracking.
            reverse: Run the inverse transform when true.

        Returns:
            ``(input, logdet)``; ``logdet`` is ``None`` when it was passed
            as ``None``.
        """
        assert mask.dtype == torch.bool
        output, logdet_out = super().forward(input, logdet, reverse)
        input[mask] = output[mask]
        # With the default logdet=None the base class also returns None,
        # and indexing None would raise; only merge when tracking is on.
        if logdet is not None:
            logdet[mask] = logdet_out[mask]
        return input, logdet