DL-Art-School/codes/models/srflow/Split.py
2020-12-18 09:24:31 -07:00

69 lines
2.5 KiB
Python

import torch
from torch import nn as nn
from models.srflow import thops
from models.srflow.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get
class Split2d(nn.Module):
    """Channel split with a learned prior over the channels that are dropped.

    Forward (``reverse=False``): the input is sliced along dim 1 into a
    "pass" part ``z1`` (the first ``num_channels_pass`` channels) and a
    "consume" part ``z2`` (the remaining channels).  ``self.conv`` predicts
    ``mean`` and ``logs`` from ``z1`` (optionally concatenated with ``ft``),
    and ``z2`` is returned standardised as ``eps``.

    Reverse (``reverse=True``): ``z2`` is rebuilt from ``eps`` (or from a
    freshly drawn sample when ``eps`` is ``None``) and joined back onto the
    input with ``thops.cat_feature``.
    """

    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None,
                 consume_ratio=0.5, opt=None):
        """Build the split layer.

        Args:
            num_channels: Number of channels (dim 1) of the tensor to split.
            logs_eps: Constant added to ``exp(logs)`` in :meth:`exp_eps`.
            cond_channels: Extra input channels of the prior conv, i.e. the
                channel count of the ``ft`` tensor concatenated in
                :meth:`split2d_prior`.
            position: Stored on the instance as-is; not used in this class.
            consume_ratio: Fraction of ``num_channels`` that is consumed
                (split off) rather than passed through.
            opt: Options object handed to ``opt_get`` to look up
                ``networks.generator.flow.gaussian_loss_weight``; ``1`` is
                passed as the fallback value.

        Raises:
            ValueError: If ``consume_ratio`` yields a consumed channel count
                outside ``[0, num_channels]``.
        """
        super().__init__()

        consumed = int(round(num_channels * consume_ratio))
        # A ratio outside [0, 1] would give a negative channel count for
        # either the consumed or the passed part; fail here with a clear
        # message instead of an obscure conv/shape error later on.
        if not 0 <= consumed <= num_channels:
            raise ValueError(
                "consume_ratio={!r} gives {} consumed channels, which is outside "
                "[0, {}]".format(consume_ratio, consumed, num_channels)
            )

        self.num_channels_consume = consumed
        self.num_channels_pass = num_channels - consumed

        # The prior network sees the passed channels (plus conditioning) and
        # emits two values per consumed channel: one for mean, one for logs.
        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=self.num_channels_consume * 2)
        self.logs_eps = logs_eps
        self.position = position
        self.gaussian_nll_weight = opt_get(
            opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1)

    def split2d_prior(self, z, ft):
        """Predict the prior parameters for the consumed channels.

        Args:
            z: Passed-through part of the input.
            ft: Optional conditioning tensor; concatenated to ``z`` on dim 1
                when it is not ``None``.

        Returns:
            The result of ``thops.split_feature(h, "cross")`` on the conv
            output; callers unpack it as ``(mean, logs)``.
        """
        prior_in = z if ft is None else torch.cat([z, ft], dim=1)
        return thops.split_feature(self.conv(prior_in), "cross")

    def exp_eps(self, logs):
        """Return ``exp(logs) + self.logs_eps``."""
        return torch.exp(logs) + self.logs_eps

    def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None,
                y_onehot=None):
        """Split (forward) or re-merge (reverse) the consumed channels.

        Args:
            input: Full tensor when ``reverse`` is false; the passed-through
                part ``z1`` when ``reverse`` is true.
            logdet: Running log-determinant to be updated.
            reverse: Selects the direction (see class docstring).
            eps_std: Forwarded to ``GaussianDiag.sample`` when ``eps`` has to
                be drawn in the reverse direction.
            eps: Standardised consumed part to reuse in the reverse
                direction; ignored (overwritten) in the forward direction.
            ft: Optional conditioning tensor for :meth:`split2d_prior`.
            y_onehot: Accepted but not used by this layer.

        Returns:
            ``(z1, logdet, eps)`` in the forward direction and ``(z, logdet)``
            in the reverse direction.  Note the differing tuple lengths.
        """
        if reverse:
            z1 = input
            mean, logs = self.split2d_prior(z1, ft)
            if eps is None:
                # The original author flagged this sampling path as untested.
                # NOTE(review): an earlier variant called
                # GaussianDiag.sample_eps(mean.shape, eps_std) here.  Whether
                # GaussianDiag.sample returns raw noise or an already
                # scaled/shifted sample cannot be seen from this file -
                # confirm, since the result is scaled and shifted again below.
                eps = GaussianDiag.sample(mean, logs, eps_std)
            eps = eps.to(mean.device)
            z2 = mean + self.exp_eps(logs) * eps
            merged = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(logs, mean, z2)
            return merged, logdet

        z1, z2 = self.split_ratio(input)
        mean, logs = self.split2d_prior(z1, ft)
        eps = (z2 - mean) / self.exp_eps(logs)
        logdet = logdet + self.get_logdet(logs, mean, z2)
        return z1, logdet, eps

    def get_logdet(self, logs, mean, z2):
        """Return ``GaussianDiag.logp(mean, logs, z2)`` scaled by ``gaussian_nll_weight``."""
        return GaussianDiag.logp(mean, logs, z2) * self.gaussian_nll_weight

    def split_ratio(self, input):
        """Slice ``input`` on dim 1 into ``(pass, consume)`` parts.

        The first ``num_channels_pass`` channels form the pass part; all
        remaining channels form the consume part.
        """
        keep = self.num_channels_pass
        return input[:, :keep, ...], input[:, keep:, ...]