# Source: DL-Art-School/codes/models/srflow/FlowUpsamplerNet.py
# Last modified: 2020-12-18 09:24:31 -07:00 (original listing: 293 lines, 13 KiB, Python)

import numpy as np
import torch
from torch import nn as nn
import models.srflow.Split
from models.srflow import flow
from models.srflow import thops
from models.srflow.Split import Split2d
from models.srflow.glow_arch import f_conv2d_bias
from models.srflow.FlowStep import FlowStep
from utils.util import opt_get, checkpoint
class FlowUpsamplerNet(nn.Module):
    """Stack of conditional flow layers mapping an image to latents (encode) and back (decode).

    The layer list is built level by level (levels 1..L). At each level the feature map is
    squeezed (channels x4, H and W halved), optionally passed through unconditional
    ``additionalFlowNoAffine`` FlowSteps, then through ``K[level]`` conditional FlowSteps,
    and finally (optionally) split with a Split2d layer.

    ``output_shapes[i]`` records the ``[-1, C, H, W]`` output shape of ``layers[i]``. The
    spatial size in that shape determines which entry of ``rrdbResults`` (via
    ``levelToName``) conditions the layer.
    """

    # flow_scale -> {level: key into rrdbResults}. A layer whose output spatial size is
    # ``patch_sz / 2**level`` is conditioned on ``rrdbResults[levelToName[level]]``.
    _LEVEL_TO_NAME = {
        16: {0: 'fea_up16', 1: 'fea_up8', 2: 'fea_up4', 3: 'fea_up2', 4: 'fea_up1'},
        8: {0: 'fea_up8', 1: 'fea_up4', 2: 'fea_up2', 3: 'fea_up1', 4: 'fea_up0'},
        4: {0: 'fea_up4', 1: 'fea_up2', 2: 'fea_up1', 3: 'fea_up0', 4: 'fea_up-1'},
    }

    def __init__(self, image_shape, hidden_channels, K, L=None,
                 actnorm_scale=1.0,
                 flow_permutation=None,
                 flow_coupling="affine",
                 LU_decomposed=False, opt=None):
        """Build the flow layers from ``opt``.

        Args:
            image_shape: ``(H, W, C)`` of the input image; ``C`` must be 1 or 3.
            hidden_channels: hidden width forwarded to every FlowStep.
            K: flow steps per level; used only when ``networks.generator.flow.K`` is not set
                in ``opt``. An int (from either source) is broadcast to ``L + 1`` levels.
            L: unused; the number of levels is read from ``networks.generator.flow.L``.
            actnorm_scale, flow_coupling, LU_decomposed: forwarded to FlowStep.
            flow_permutation: unused; read from ``networks.generator.flow.flow_permutation``.
            opt: nested option dict; everything is read under ``networks.generator``.

        Raises:
            ValueError: if ``networks.generator.flow_scale`` is not 16, 8 or 4.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.L = opt_get(opt, ['networks', 'generator', 'flow', 'L'])
        opt_K = opt_get(opt, ['networks', 'generator', 'flow', 'K'])
        self.K = K if opt_K is None else opt_K
        if isinstance(self.K, int):
            # Same number of flow steps at every level.
            self.K = [self.K] * (self.L + 1)
        self.patch_sz = opt_get(opt, ['networks', 'generator', 'flow', 'patch_size'], 160)
        self.opt = opt
        H, W, self.C = image_shape
        self.check_image_shape()

        flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'])
        if flow_scale not in self._LEVEL_TO_NAME:
            raise ValueError('Unsupported flow_scale {!r}; expected one of {}'.format(
                flow_scale, sorted(self._LEVEL_TO_NAME)))
        # Copy so per-instance edits never leak into the shared class table.
        self.levelToName = dict(self._LEVEL_TO_NAME[flow_scale])

        affineInCh = self.get_affineInCh(opt_get)
        flow_permutation = self.get_flow_permutation(flow_permutation, opt)
        normOpt = opt_get(opt, ['networks', 'generator', 'flow', 'norm'])

        # Conditioning channels per level: the RRDB features, plus (optionally) bypass
        # features from every later level level+1..L, i.e. (L - level) * n_bypass_channels.
        n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
        n_bypass_channels = opt_get(opt, ['networks', 'generator', 'flow', 'levelConditional', 'n_channels'])
        conditional_channels = {0: n_rrdb}
        for level in range(1, self.L + 1):
            n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
            conditional_channels[level] = n_rrdb + n_bypass

        # Upsampler
        for level in range(1, self.L + 1):
            # 1. Squeeze
            H, W = self.arch_squeeze(H, W)
            # 2. Optional unconditional steps, then K[level] conditional FlowSteps
            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
            self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                               flow_permutation,
                               hidden_channels, normOpt, opt, opt_get,
                               n_conditinal_channels=conditional_channels[level])
            # 3. Optional split
            self.arch_split(H, W, level, self.L, opt, opt_get)

        if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'enable']):
            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
        else:
            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)

        self.H = H
        self.W = W
        self.scaleH = self.patch_sz / H
        self.scaleW = self.patch_sz / W

    def get_n_rrdb_channels(self, opt, opt_get):
        """Return ``64 * (len(flow.stackRRDB.blocks) + 1)``, or 64 when no blocks are configured."""
        blocks = opt_get(opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks'])
        n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
        return n_rrdb

    def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
                      hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
        """Append ``K`` conditional FlowSteps on ``self.C`` channels at spatial size ``H x W``.

        ``condAff`` (from opt) gets ``in_channels_rrdb`` set, and ``normOpt`` gets ``position``
        set, in place before being handed to each FlowStep.
        NOTE(review): both are the same dict objects across levels; this is only correct if
        FlowStep reads these keys at construction time — confirm.
        """
        condAff = self.get_condAffSetting(opt, opt_get)
        if condAff is not None:
            condAff['in_channels_rrdb'] = n_conditinal_channels
        for k in range(K):
            position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale']))
            if normOpt:
                normOpt['position'] = position_name
            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         acOpt=condAff,
                         position=position_name,
                         LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def get_condAffSetting(self, opt, opt_get):
        """Return ``flow.condFtAffine`` if truthy, else ``flow.condAff`` if truthy, else None."""
        condAff = opt_get(opt, ['networks', 'generator', 'flow', 'condAff']) or None
        condAff = opt_get(opt, ['networks', 'generator', 'flow', 'condFtAffine']) or condAff
        return condAff

    def arch_split(self, H, W, L, levels, opt, opt_get):
        """Append a Split2d after level ``L`` (of ``levels``) when ``flow.split.enable`` is set.

        By default no split is added at the last two levels. With ``split.correct_splits``
        only the last level is skipped. After a split, ``self.C`` becomes the number of
        channels the split passes on.
        """
        correct_splits = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'correct_splits'], False)
        correction = 0 if correct_splits else 1
        if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'enable']) and L < levels - correction:
            logs_eps = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'logs_eps']) or 0
            consume_ratio = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'consume_ratio']) or 0.5
            position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale']))
            position = position_name if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'conditional']) else None
            cond_channels = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'cond_channels'])
            cond_channels = 0 if cond_channels is None else cond_channels
            t = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'type'], 'Split2d')
            if t == 'Split2d':
                split = models.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                                    cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
            # NOTE(review): any other split type falls through to the lines below with `split`
            # unbound (NameError) — kept as in the original.
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            self.C = split.num_channels_pass

    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
        """Append ``flow.additionalFlowNoAffine`` FlowSteps (invconv, no coupling) if configured."""
        if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']:
            n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine'])
            for _ in range(n_additionalFlowNoAffine):
                self.layers.append(
                    FlowStep(in_channels=self.C,
                             hidden_channels=hidden_channels,
                             actnorm_scale=actnorm_scale,
                             flow_permutation='invconv',
                             flow_coupling='noCoupling',
                             LU_decomposed=LU_decomposed, opt=opt))
                self.output_shapes.append(
                    [-1, self.C, H, W])

    def arch_squeeze(self, H, W):
        """Append a factor-2 SqueezeLayer (channels x4, H and W halved); return the new ``(H, W)``."""
        self.C, H, W = self.C * 4, H // 2, W // 2
        self.layers.append(flow.SqueezeLayer(factor=2))
        self.output_shapes.append([-1, self.C, H, W])
        return H, W

    def get_flow_permutation(self, flow_permutation, opt):
        """Return ``flow.flow_permutation`` from ``opt`` (default 'invconv'); the argument is ignored."""
        flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv')
        return flow_permutation

    def get_affineInCh(self, opt_get):
        """Return ``64 * (len(flow.stackRRDB.blocks) + 1)`` read from ``self.opt`` (64 if unset/empty)."""
        affineInCh = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
        affineInCh = (len(affineInCh) + 1) * 64
        return affineInCh

    def check_image_shape(self):
        """Assert the image has 1 or 3 channels (``image_shape`` is HWC)."""
        assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3); "
                                            "self.C == 1 or self.C == 3")

    def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
                y_onehot=None):
        """Encode ``gt`` (``reverse=False``) or decode latents into an image (``reverse=True``).

        In reverse mode a list ``epses`` is shallow-copied first, because ``decode`` pops from
        it and the caller's list must not be consumed.
        """
        if reverse:
            epses_copy = list(epses) if isinstance(epses, list) else epses
            sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
            return sr, logdet
        else:
            assert gt is not None
            assert rrdbResults is not None
            z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
            return z, logdet

    def _level_of_size(self, size):
        """Return ``floor(log2(patch_sz / size))``: the level whose output spatial size is ``size``.

        ``np.log2`` is exact for power-of-two ratios, unlike a ratio of natural logs.
        """
        return int(np.log2(self.patch_sz / size))

    def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
        """Run the layers forward on ``gt``.

        Each layer is conditioned on ``rrdbResults[levelToName[level]]``, where ``level`` is
        derived from the layer's output spatial size. FlowSteps are run through ``checkpoint``.

        Returns:
            ``(z, logdet)`` when ``epses`` is not a list. Otherwise the final latent is appended
            to ``epses`` (after the Split2d outputs) and ``(epses, logdet)`` is returned.
        """
        fl_fea = gt
        for layer, shape in zip(self.layers, self.output_shapes):
            level = self._level_of_size(shape[2])
            conditional = rrdbResults[self.levelToName[level]]
            if isinstance(layer, FlowStep):
                fl_fea, logdet = checkpoint(layer, fl_fea, logdet, conditional)
            elif isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, False, conditional,
                                                      y_onehot=y_onehot)
            else:
                fl_fea, logdet = layer(fl_fea, logdet, reverse=False)
        z = fl_fea
        if not isinstance(epses, list):
            return z, logdet
        epses.append(z)
        return epses, logdet

    def forward_preFlow(self, fl_fea, logdet, reverse):
        """Apply ``self.preFlow`` layers if present (nothing in this class creates that attribute)."""
        if hasattr(self, 'preFlow'):
            for l in self.preFlow:
                fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
        return fl_fea, logdet

    def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
        """Run a Split2d layer forward and append the eps it returns to ``epses``.

        ``epses`` must be a list whenever a split layer exists (``append`` is unconditional).
        NOTE(review): ``encode`` passes the per-level conditioning entry as ``rrdbResults``,
        so ``rrdbResults[layer.position]`` only works if that entry is indexable by name — confirm.
        """
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
        epses.append(eps)
        return fl_fea, logdet

    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
        """Run the layers in reverse to reconstruct a 3-channel image.

        When ``epses`` is a list, its last element is used as the latent, and Split2d layers
        pop the remaining entries; otherwise ``z`` is used and splits get ``eps=None``.
        NOTE(review): when ``flow.levelConditional.conditional`` is True, ``level_conditionals``
        stays empty and any FlowStep lookup below raises KeyError.
        """
        z = epses.pop() if isinstance(epses, list) else z
        fl_fea = z
        level_conditionals = {}
        if not opt_get(self.opt, ['networks', 'generator', 'flow', 'levelConditional', 'conditional']) == True:
            for level in range(self.L + 1):
                level_conditionals[level] = rrdbResults[self.levelToName[level]]
        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            level = self._level_of_size(shape[2])
            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
        sr = fl_fea
        assert sr.shape[1] == 3
        return sr, logdet

    def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
        """Run a Split2d layer in reverse, popping its eps from ``epses`` when it is a list."""
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                               eps=epses.pop() if isinstance(epses, list) else None,
                               eps_std=eps_std, ft=ft, y_onehot=y_onehot)
        return fl_fea, logdet

    def get_position_name(self, H, scale):
        """Return ``'fea_up{scale / (patch_sz // H)}'``.

        True division makes this a float string such as ``'fea_up4.0'``.
        NOTE(review): that differs from the integer-style keys in ``levelToName`` (e.g.
        ``'fea_up4'``); confirm consumers of ``position`` expect this format.
        """
        downscale_factor = self.patch_sz // H
        position_name = 'fea_up{}'.format(scale / downscale_factor)
        return position_name