# DL-Art-School/codes/models/archs/srflow/FlowUpsamplerNet.py
# 2020-11-19 21:42:24 -07:00
#
# 268 lines
# 10 KiB
# Python

import numpy as np
import torch
from torch import nn as nn
import models.archs.srflow.Split
from models.archs.srflow import flow, thops, Split
from models.archs.srflow.Split import Split2d
from models.archs.srflow.glow_arch import f_conv2d_bias
from models.archs.srflow.FlowStep import FlowStep
from utils.util import opt_get
import torchvision
class FlowUpsamplerNet(nn.Module):
    """Multi-level stack of flow layers conditioned on RRDB feature maps.

    The network is built as ``L`` levels.  Every level appends, in order, a
    squeeze layer, a fixed number of coupling-free ``FlowStep`` layers,
    ``K[level]`` configurable ``FlowStep`` layers and (except on the last
    level by default) a ``Split2d`` layer.  ``forward`` runs the stack either
    in the encoding direction (``gt`` -> latent) or, with ``reverse=True``,
    in the decoding direction (latent -> image).
    """

    # Key of the conditioning feature map in ``rrdbResults`` for each flow
    # level, indexed by the overall upscaling factor of the model.
    _LEVEL_NAMES_BY_SCALE = {
        16: {0: 'fea_up16', 1: 'fea_up8', 2: 'fea_up4', 3: 'fea_up2', 4: 'fea_up1'},
        8: {0: 'fea_up8', 1: 'fea_up4', 2: 'fea_up2', 3: 'fea_up1', 4: 'fea_up0'},
        4: {0: 'fea_up4', 1: 'fea_up2', 2: 'fea_up1', 3: 'fea_up0', 4: 'fea_up-1'},
    }

    def __init__(self, image_shape, hidden_channels, scale,
                 rrdb_blocks,
                 actnorm_scale=1.0,
                 flow_permutation='invconv',
                 flow_coupling="affine",
                 LU_decomposed=False, K=16, L=3,
                 norm_opt=None,
                 n_bypass_channels=None):
        """Build the layer stack.

        Args:
            image_shape: ``(H, W, C)`` of the training image; ``C`` must be
                1 or 3.
            hidden_channels: forwarded to every ``FlowStep``.
            scale: overall upscaling factor; one of 4, 8 or 16.
            rrdb_blocks: collection of RRDB block identifiers (or ``None``);
                only its length is used, to derive channel counts.
            actnorm_scale: forwarded to every ``FlowStep``.
            flow_permutation: forwarded to the configurable ``FlowStep`` layers.
            flow_coupling: forwarded to the configurable ``FlowStep`` layers.
            LU_decomposed: forwarded to every ``FlowStep``.
            K: number of configurable flow steps per level, either a single
                int (used for every level) or a sequence indexed by level.
            L: number of levels.
            norm_opt: optional dict forwarded to ``FlowStep`` as ``normOpt``.
            n_bypass_channels: optional per-level bypass channel count that
                is added to the conditional channel count.

        Raises:
            ValueError: if ``scale`` has no level-name table.
        """
        super().__init__()

        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.L = L
        self.K = K
        self.scale = scale
        if isinstance(self.K, int):
            # Same number of steps on every level (index 0 is unused).
            self.K = [K] * (self.L + 1)

        H, W, self.C = image_shape
        self.image_shape = image_shape
        self.check_image_shape()

        # Fail at construction time instead of leaving ``levelToName`` unset
        # and crashing later inside encode()/decode().
        if scale not in self._LEVEL_NAMES_BY_SCALE:
            raise ValueError(
                "Unsupported scale {!r}; expected one of {}".format(
                    scale, sorted(self._LEVEL_NAMES_BY_SCALE)))
        self.levelToName = dict(self._LEVEL_NAMES_BY_SCALE[scale])

        affineInCh = self.get_affineInCh(rrdb_blocks)

        n_rrdb = self.get_n_rrdb_channels(rrdb_blocks)
        conditional_channels = {0: n_rrdb}
        for level in range(1, self.L + 1):
            # Level 1 gets conditionals from 2, 3, 4 => L - level
            # Level 2 gets conditionals from 3, 4
            # Level 3 gets conditionals from 4
            # Level 4 gets conditionals from None
            n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
            conditional_channels[level] = n_rrdb + n_bypass

        # Upsampler
        for level in range(1, self.L + 1):
            # 1. Squeeze
            H, W = self.arch_squeeze(H, W)

            # 2. K FlowStep
            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels)
            self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                               flow_permutation,
                               hidden_channels, norm_opt,
                               n_conditional_channels=conditional_channels[level])

            # Split
            self.arch_split(H, W, level, self.L)

        self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)

        # Spatial size after the last squeeze.
        self.H = H
        self.W = W

    def get_n_rrdb_channels(self, blocks):
        """Return 64 channels per entry of ``blocks`` plus 64 (64 if ``None``)."""
        n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
        return n_rrdb

    def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
                      hidden_channels, normOpt, n_conditional_channels=None, condAff=None):
        """Append ``K`` configurable ``FlowStep`` layers for the current level.

        ``condAff`` and ``normOpt`` are mutated in place when given:
        ``condAff['in_channels_rrdb']`` receives ``n_conditional_channels``
        and ``normOpt['position']`` receives the position name for ``H``.
        """
        if condAff is not None:
            condAff['in_channels_rrdb'] = n_conditional_channels

        for k in range(K):
            position_name = self.get_position_name(H, self.scale)
            # NOTE(review): the same dict object is shared by every step, so
            # a later level overwrites 'position' for earlier ones unless
            # FlowStep copies it at construction - confirm against FlowStep.
            if normOpt:
                normOpt['position'] = position_name

            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         acOpt=condAff,
                         position=position_name,
                         LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5, split_conditional=False, cond_channels=None, split_type='Split2d'):
        """Append a split layer for level ``L`` unless this level is exempt.

        A split is added only when ``split_flow`` is true and
        ``L < levels - correction``, where ``correction`` is 0 with
        ``correct_splits`` and 1 otherwise (so by default the last level
        gets no split).  After the split, ``self.C`` becomes the number of
        channels the split passes on.

        Raises:
            ValueError: if a split is due and ``split_type`` is not
                ``'Split2d'``.
        """
        correction = 0 if correct_splits else 1
        if not (split_flow and L < levels - correction):
            return

        if split_type != 'Split2d':
            # Previously this fell through to an unbound ``split`` variable.
            raise ValueError("Unsupported split_type: {!r}".format(split_type))

        position_name = self.get_position_name(H, self.scale)
        position = position_name if split_conditional else None
        cond_channels = 0 if cond_channels is None else cond_channels

        split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                              cond_channels=cond_channels, consume_ratio=consume_ratio)
        self.layers.append(split)
        self.output_shapes.append([-1, split.num_channels_pass, H, W])
        self.C = split.num_channels_pass

    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2):
        """Append ``additionalFlowNoAffine`` ``FlowStep`` layers built with
        ``flow_permutation='invconv'`` and ``flow_coupling='noCoupling'``."""
        for _ in range(additionalFlowNoAffine):
            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation='invconv',
                         flow_coupling='noCoupling',
                         LU_decomposed=LU_decomposed))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def arch_squeeze(self, H, W):
        """Append a factor-2 squeeze layer and return the halved ``(H, W)``.

        ``self.C`` is multiplied by 4 to track the channel count.
        """
        self.C, H, W = self.C * 4, H // 2, W // 2
        self.layers.append(flow.SqueezeLayer(factor=2))
        self.output_shapes.append([-1, self.C, H, W])
        return H, W

    def get_affineInCh(self, rrdb_blocks):
        """Return 64 channels per entry of ``rrdb_blocks`` plus 64.

        ``None`` is treated as "no blocks" (64 channels), matching
        ``get_n_rrdb_channels``.
        """
        affineInCh = 64 if rrdb_blocks is None else (len(rrdb_blocks) + 1) * 64
        return affineInCh

    def check_image_shape(self):
        """Assert that the channel count taken from ``image_shape`` is 1 or 3."""
        assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3); "
                                            "self.C == 1 or self.C == 3")

    def _level_for_size(self, size):
        """Map a layer's spatial size to its flow level.

        The level is ``floor(log2(image_shape[0] / size))``: 0 at full
        resolution, 1 after one squeeze, and so on.  It is computed with
        integer arithmetic so that exact powers of two cannot be truncated
        to the level below by floating-point rounding.
        """
        ratio = int(self.image_shape[0]) // int(size)
        return ratio.bit_length() - 1

    def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
                y_onehot=None):
        """Run the flow in the encoding or (``reverse=True``) decoding direction.

        Returns:
            ``(sr, logdet)`` when ``reverse`` is true, otherwise
            ``(z, logdet)`` as produced by ``encode``.
        """
        if reverse:
            # decode() pops from the list, so work on a shallow copy to leave
            # the caller's list intact.
            epses_copy = list(epses) if isinstance(epses, list) else epses

            sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
            return sr, logdet

        assert gt is not None
        assert rrdbResults is not None
        z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
        return z, logdet

    def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
        """Push ``gt`` through all layers in construction order.

        ``rrdbResults`` is indexed with the name from ``levelToName`` for the
        level of each layer.  When ``epses`` is a list, the split layers'
        eps values and finally the latent itself are appended to it and the
        list is returned in place of the latent.
        """
        fl_fea = gt
        reverse = False

        for layer, shape in zip(self.layers, self.output_shapes):
            level = self._level_for_size(shape[2])
            conditional = rrdbResults[self.levelToName[level]]

            if isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=conditional)
            elif isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, conditional,
                                                      y_onehot=y_onehot)
            else:
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)

        z = fl_fea

        if not isinstance(epses, list):
            return z, logdet

        epses.append(z)
        return epses, logdet

    def forward_preFlow(self, fl_fea, logdet, reverse):
        """Apply every layer of ``self.preFlow`` in order, if that attribute exists."""
        if hasattr(self, 'preFlow'):
            for l in self.preFlow:
                fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
        return fl_fea, logdet

    def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
        """Run a split layer forward, appending its eps to ``epses`` if that is a list."""
        # NOTE(review): callers pass the already-selected conditional here, so
        # indexing it by ``layer.position`` only works if that object supports
        # such a lookup - confirm before enabling conditional splits.
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)

        if isinstance(epses, list):
            epses.append(eps)
        return fl_fea, logdet

    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
        """Run all layers in reverse order, starting from the latent.

        When ``epses`` is a list, the latent is popped from its end and each
        split layer pops its own eps; otherwise ``z`` is used and the split
        layers receive ``eps=None``.
        """
        z = epses.pop() if isinstance(epses, list) else z

        fl_fea = z
        # debug.imwrite("fl_fea", fl_fea)
        level_conditionals = {}
        for level in range(self.L + 1):
            level_conditionals[level] = rrdbResults[self.levelToName[level]]

        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            # Levels must be measured against the full image size, exactly as
            # in encode(); measuring against the final squeezed height gives
            # non-positive levels and selects the wrong (or no) conditional.
            level = self._level_for_size(shape[2])

            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)

        sr = fl_fea

        # The channel count must match the one the stack was built for
        # (check_image_shape allows 1 as well as 3).
        assert sr.shape[1] == self.image_shape[2]
        return sr, logdet

    def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
        """Run a split layer in reverse, feeding it the next eps popped from ``epses``."""
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                               eps=epses.pop() if isinstance(epses, list) else None,
                               eps_std=eps_std, ft=ft, y_onehot=y_onehot)
        return fl_fea, logdet

    def get_position_name(self, H, scale):
        """Return the ``'fea_up<factor>'`` position label for spatial height ``H``.

        The factor is ``scale`` divided by ``image_shape[0] // H``.  True
        division is used, so the factor is formatted as a float (for example
        ``'fea_up2.0'``).
        """
        downscale_factor = self.image_shape[0] // H
        position_name = 'fea_up{}'.format(scale / downscale_factor)
        return position_name