import torch
from torch import nn as nn

from models.modules import thops


class _ActNorm(nn.Module):
    """
    Activation Normalization

    Initialize the bias and scale with a given minibatch,
    so that the output per-channel has zero mean and unit variance for that.
    After initialization, `bias` and `logs` will be trained as parameters.

    Subclasses override `_check_input_dim`; the base implementation performs
    no validation.
    """

    def __init__(self, num_features, scale=1.):
        super().__init__()
        # register mean and scale; shaped (1, C, 1, 1) so they broadcast
        # against a 4-D input
        size = [1, num_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        # Target output standard deviation used by the data-dependent init.
        self.scale = float(scale)
        # Plain attribute (not a registered parameter/buffer).
        self.inited = False

    def _check_input_dim(self, input):
        """Validate `input`'s shape; the base class accepts anything."""
        return NotImplemented

    def initialize_parameters(self, input):
        """Data-dependent initialization of `bias` and `logs` from `input`.

        Sets `bias` to minus the per-channel mean (thops.mean over dims
        [0, 2, 3], assumed to reduce over the listed dims) and `logs` to
        log(scale / (std + 1e-6)). The effect is that the centered, scaled
        batch has roughly zero mean and standard deviation `self.scale`.

        Does nothing in eval mode. If `bias` already has any nonzero entry,
        the parameters are left untouched and the module is only marked as
        initialized.
        """
        self._check_input_dim(input)
        if not self.training:
            return
        # A nonzero bias presumably means parameters were already set
        # (e.g. restored from a checkpoint); don't overwrite them.
        if (self.bias != 0).any():
            self.inited = True
            return
        assert input.device == self.bias.device, (input.device, self.bias.device)
        with torch.no_grad():
            bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
            # `input + bias` already allocates a new tensor; no clone needed.
            var = thops.mean((input + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            # 1e-6 avoids division by zero for constant channels.
            logs = torch.log(self.scale / (torch.sqrt(var) + 1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.inited = True

    def _center(self, input, reverse=False, offset=None):
        """Add `bias` (+ optional `offset`) to `input`, or subtract it in reverse."""
        bias = self.bias
        if offset is not None:
            bias = bias + offset
        if not reverse:
            return input + bias
        return input - bias

    def _scale(self, input, logdet=None, reverse=False, offset=None):
        """Multiply `input` by exp(logs [+ offset]), or by exp(-...) in reverse.

        When `logdet` is not None, adds thops.sum(logs) * thops.pixels(input)
        to it (negated in reverse) and returns the updated value. Otherwise
        the returned logdet is None.
        """
        logs = self.logs
        if offset is not None:
            logs = logs + offset
        if not reverse:
            input = input * torch.exp(logs)
        else:
            input = input * torch.exp(-logs)
        if logdet is not None:
            # logs is a per-channel log-scale applied at every spatial
            # position, so its sum is multiplied by the pixel count
            # (thops.pixels presumably returns H * W -- confirm).
            # NOTE(review): if `offset` carries a batch dimension,
            # thops.sum(logs) may mix samples together -- confirm intended.
            dlogdet = thops.sum(logs) * thops.pixels(input)
            if reverse:
                dlogdet *= -1
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False, offset_mask=None,
                logs_offset=None, bias_offset=None):
        """Apply (or invert) the affine normalization.

        Forward: center, then scale. Reverse: unscale, then un-center.
        Runs the data-dependent init on the first call while not yet
        initialized.

        Args:
            input: tensor validated by `_check_input_dim`.
            logdet: running log-determinant, or None to skip tracking.
            reverse: if True, apply the inverse transform.
            offset_mask: optional multiplier applied to both offsets.
            logs_offset: optional additive offset to `logs`.
            bias_offset: optional additive offset to `bias`.

        Returns:
            (output, logdet) -- logdet is None if it was passed as None.
        """
        if not self.inited:
            self.initialize_parameters(input)
        self._check_input_dim(input)

        if offset_mask is not None:
            # Out-of-place so the caller's offset tensors are not modified,
            # and guarded so a missing offset does not raise TypeError.
            if logs_offset is not None:
                logs_offset = logs_offset * offset_mask
            if bias_offset is not None:
                bias_offset = bias_offset * offset_mask

        # no need to permute dims as old version
        if not reverse:
            # center and scale
            input = self._center(input, reverse, bias_offset)
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
        else:
            # scale and center
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
            input = self._center(input, reverse, bias_offset)
        return input, logdet


class ActNorm2d(_ActNorm):
    """ActNorm for 4-D inputs laid out as (B, C, H, W)."""

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        """Assert `input` is 4-D with `num_features` channels in dim 1."""
        assert len(input.size()) == 4, (
            "[ActNorm]: input should be 4-D `BCHW`, got shape {}".format(
                tuple(input.size())))
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size(1)))


class MaskedActNorm2d(ActNorm2d):
    """ActNorm2d that only updates the entries selected by a boolean mask."""

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def forward(self, input, mask, logdet=None, reverse=False):
        """Apply ActNorm2d, keeping results only where `mask` is True.

        `mask` is a bool tensor used to index both `input` and `logdet`
        (presumably shape (B,) selecting batch entries -- confirm).
        Unselected entries are returned unchanged.

        Returns new tensors; the caller's `input`/`logdet` are not modified.
        If `logdet` is None, None is returned for it.
        """
        assert mask.dtype == torch.bool
        output, logdet_out = super().forward(input, logdet, reverse)
        # Clone before the masked write so the caller's tensors (which
        # autograd may have saved) are not mutated in place.
        input = input.clone()
        input[mask] = output[mask]
        if logdet is not None:
            logdet = logdet.clone()
            logdet[mask] = logdet_out[mask]
        return input, logdet