# DL-Art-School/codes/models/archs/srflow_orig/Split.py
# 2020-11-19 23:47:24 -07:00
#
# 70 lines
# 2.4 KiB
# Python

import torch
from torch import nn as nn
from models.archs.srflow_orig import thops
from models.archs.srflow_orig.FlowStep import FlowStep
from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
from utils.util import opt_get
class Split2d(nn.Module):
    """Channel-split layer with a learned, optionally conditioned, prior.

    In the forward (encoding) direction the input is cut along dim 1 into a
    "pass" part ``z1`` (the first ``num_channels_pass`` channels) and a
    "consume" part ``z2`` (the remaining channels).  ``self.conv`` is applied
    to ``z1`` (concatenated with the conditioning features ``ft`` when given)
    and its output is split into ``mean`` and ``logs``, which parameterise the
    prior over ``z2``.  Only ``z1`` is passed on; ``z2`` is returned in
    normalised form as ``eps``.

    In the reverse direction ``z2`` is rebuilt from ``eps`` (sampled when not
    supplied) and joined back onto ``z1``.

    NOTE(review): ``get_logdet`` is given the raw ``logs`` while the scaling
    uses ``exp(logs) + logs_eps``.  The two only agree when ``logs_eps == 0``
    (the default) -- confirm against ``GaussianDiag.logp`` before using a
    non-zero ``logs_eps``.
    """

    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
        """Build the split layer.

        Args:
            num_channels: Channel count of the tensor received in the forward
                direction.
            logs_eps: Constant added to ``exp(logs)`` in :meth:`exp_eps`.
            cond_channels: Channel count of the conditioning tensor ``ft``
                that is concatenated to ``z1`` before the convolution.
            position: Stored on the instance; not read inside this class.
            consume_ratio: Fraction of ``num_channels`` that is consumed.  The
                count is ``int(round(num_channels * consume_ratio))``, so
                exact halves follow Python's round-half-to-even rule.
            opt: Stored on the instance; not read inside this class.
        """
        super().__init__()

        self.num_channels_consume = int(round(num_channels * consume_ratio))
        self.num_channels_pass = num_channels - self.num_channels_consume

        # Two output channels per consumed channel: one for `mean`, one for `logs`.
        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=self.num_channels_consume * 2)
        self.logs_eps = logs_eps
        self.position = position
        self.opt = opt

    def split2d_prior(self, z, ft):
        """Return ``(mean, logs)`` predicted from ``z`` and optional features ``ft``.

        ``ft`` is concatenated to ``z`` along dim 1 when it is not ``None``.
        The convolution output is divided by ``thops.split_feature(h, "cross")``
        (presumably an interleaved channel split -- verify in ``thops``).
        """
        if ft is not None:
            z = torch.cat([z, ft], dim=1)
        h = self.conv(z)
        return thops.split_feature(h, "cross")

    def exp_eps(self, logs):
        """Return ``exp(logs) + self.logs_eps``, the scale applied to ``eps``."""
        return torch.exp(logs) + self.logs_eps

    def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
        """Run the split (``reverse=False``) or its inverse (``reverse=True``).

        Args:
            input: Full tensor when encoding; only the pass-through part
                ``z1`` when reversing.
            logdet: Running log-determinant / log-likelihood accumulator.
            reverse: Selects the direction.
            eps_std: Forwarded to ``GaussianDiag.sample_eps`` when ``eps`` has
                to be sampled (reverse only).
            eps: Normalised latent for the consumed channels (reverse only);
                sampled when ``None``.
            ft: Optional conditioning features for the prior.
            y_onehot: Accepted for interface compatibility; unused here.

        Returns:
            ``(z1, logdet, eps)`` when ``reverse`` is false, otherwise
            ``(z, logdet)``.  Note that the tuple length differs by direction.
        """
        if not reverse:
            z1, z2 = self.split_ratio(input)
            mean, logs = self.split2d_prior(z1, ft)

            # Normalise the consumed channels under the predicted prior.
            eps = (z2 - mean) / self.exp_eps(logs)
            logdet = logdet + self.get_logdet(logs, mean, z2)
            return z1, logdet, eps

        z1 = input
        mean, logs = self.split2d_prior(z1, ft)

        if eps is None:
            # How eps_std=None is treated is decided inside sample_eps
            # (not visible from this file).
            eps = GaussianDiag.sample_eps(mean.shape, eps_std)

        # Applies to sampled and caller-supplied eps alike, so either can be
        # combined with `mean` regardless of where it was created.
        eps = eps.to(mean.device)
        z2 = mean + self.exp_eps(logs) * eps

        z = thops.cat_feature(z1, z2)
        logdet = logdet - self.get_logdet(logs, mean, z2)
        return z, logdet

    def get_logdet(self, logs, mean, z2):
        """Return ``GaussianDiag.logp(mean, logs, z2)``, the term added to ``logdet``."""
        return GaussianDiag.logp(mean, logs, z2)

    def split_ratio(self, input):
        """Slice ``input`` along dim 1 into ``(z1, z2)`` at ``num_channels_pass``."""
        n_pass = self.num_channels_pass
        z1 = input[:, :n_pass, ...]
        z2 = input[:, n_pass:, ...]
        return z1, z2