43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import numpy as np
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from models.modules import thops
|
|
|
|
|
|
class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution that mixes channels with a learned square matrix W.

    W is a ``num_channels x num_channels`` parameter initialised to a random
    orthogonal matrix (the Q factor of a QR decomposition of a standard-normal
    matrix), so the initial transform has |det W| == 1.

    Args:
        num_channels: number of input (and output) channels.
        LU_decomposed: stored as ``self.LU`` but never read by this class;
            the weight is always a full dense matrix.
    """

    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        w_shape = [num_channels, num_channels]
        # Q from the QR decomposition of a Gaussian matrix is orthogonal,
        # giving a volume-preserving starting point.
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        self.w_shape = w_shape
        # NOTE(review): no LU parameterisation is implemented here; the flag is
        # kept only so existing callers passing it keep working.
        self.LU = LU_decomposed

    def get_weight(self, input, reverse):
        """Return the 1x1 conv kernel for the requested direction and its log-det.

        Args:
            input: tensor the kernel will be applied to; only passed to
                ``thops.pixels`` to scale the log-determinant (presumably the
                number of spatial positions, H*W — confirm against thops).
            reverse: if truthy, build the kernel for W^{-1} instead of W.

        Returns:
            ``(weight, dlogdet)`` where ``weight`` has shape (C, C, 1, 1) for
            ``F.conv2d`` and ``dlogdet = log|det W| * pixels``. This value is
            always computed for the forward matrix W. Callers negate it in
            reverse.
        """
        w_shape = self.w_shape
        pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if not reverse:
            weight = self.weight
        else:
            # Invert in float64 (presumably for accuracy), then cast back to the
            # parameter's own dtype rather than hard-coding float32. This keeps
            # the kernel compatible with a module moved to .double()/.half().
            weight = torch.inverse(self.weight.double()).to(self.weight.dtype)
        return weight.view(w_shape[0], w_shape[1], 1, 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """Apply W (or W^{-1} when ``reverse``) as a 1x1 convolution.

        The Jacobian log-determinant is ``log|det(W)| * pixels``. It is added
        to ``logdet`` in the forward direction and subtracted in reverse.

        Args:
            input: tensor passed to ``F.conv2d``.
            logdet: running log-determinant to update, or None to skip it.
            reverse: if truthy, apply the inverse transform.

        Returns:
            ``(z, logdet)``: the convolved tensor and the updated
            log-determinant (None when ``logdet`` was None).
        """
        weight, dlogdet = self.get_weight(input, reverse)
        z = F.conv2d(input, weight)
        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet
|