DL-Art-School/dlas/models/image_generation/srflow/Permutations.py
2023-03-21 15:38:42 +00:00

43 lines
1.4 KiB
Python

import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from models.image_generation.srflow import thops
class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution whose kernel is a learned square matrix.

    The layer owns a single ``(num_channels, num_channels)`` parameter named
    ``weight``. The forward direction convolves the input with that matrix
    (viewed as a 1x1 kernel); the reverse direction convolves with its matrix
    inverse. Alongside the output, the layer can accumulate the
    log-determinant contribution ``log|det(W)| * pixels``.
    """

    def __init__(self, num_channels, LU_decomposed=False):
        """Create the layer with a random orthogonal initial weight.

        Args:
            num_channels: Size of the square weight matrix (rows == columns).
            LU_decomposed: Stored on ``self.LU``. NOTE(review): nothing in
                this class reads the flag, so it currently has no effect.
        """
        super().__init__()
        w_shape = [num_channels, num_channels]
        # The Q factor of a QR decomposition is orthogonal, so the initial
        # matrix is invertible and its log|det| is 0.
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        self.weight = nn.Parameter(torch.Tensor(w_init))
        self.w_shape = w_shape
        self.LU = LU_decomposed

    def get_weight(self, input, reverse):
        """Return the 1x1 conv kernel for the requested direction and its log-det term.

        Args:
            input: Tensor handed to ``thops.pixels`` to obtain the multiplier
                for the log-determinant (presumably the number of spatial
                positions -- confirm against ``thops``).
            reverse: When true, the kernel is built from the inverse of
                ``self.weight``; otherwise from ``self.weight`` itself.

        Returns:
            A ``(kernel, dlogdet)`` tuple. ``kernel`` has shape
            ``(C, C, 1, 1)``. ``dlogdet`` is ``log|det(self.weight)| * pixels``
            and is computed from the non-inverted weight in both directions.
        """
        w_shape = self.w_shape
        pixels = thops.pixels(input)
        # slogdet returns (sign, log|det|); only the magnitude is needed.
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if reverse:
            # Invert in float64 for numerical stability, then cast back to the
            # parameter's own dtype so the kernel matches the forward one
            # (a hard-coded float32 cast breaks modules converted to double).
            kernel = torch.inverse(self.weight.double()).to(self.weight.dtype)
        else:
            kernel = self.weight
        return kernel.view(w_shape[0], w_shape[1], 1, 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """Apply the 1x1 convolution (or its inverse) and update ``logdet``.

        log-det = log|det(W)| * pixels

        Args:
            input: Tensor passed to ``F.conv2d`` with the 1x1 kernel.
            logdet: Optional running log-determinant. When ``None`` it is
                returned unchanged (``None``).
            reverse: ``False`` applies ``W`` and adds the log-det term;
                ``True`` applies ``W``'s inverse and subtracts it.

        Returns:
            A ``(z, logdet)`` tuple: the convolved tensor and the updated
            (or still ``None``) log-determinant.
        """
        weight, dlogdet = self.get_weight(input, reverse)
        # Both directions are the same convolution; only the kernel and the
        # sign of the log-det contribution differ.
        z = F.conv2d(input, weight)
        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet