DL-Art-School/codes/models/archs/srflow/FlowUpsamplerNet.py

import numpy as np
import torch
from torch import nn as nn
import torchvision

from models.archs.srflow import flow, thops, Split
from models.archs.srflow.Split import Split2d
from models.archs.srflow.glow_arch import f_conv2d_bias
from models.archs.srflow.FlowStep import FlowStep
from utils.util import opt_get


class FlowUpsamplerNet(nn.Module):
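    """
    Conditional normalizing flow mapping between HR images and latent codes,
    conditioned on RRDB features (SRFlow-style).

    The flow is assembled level by level: each level squeezes the spatial
    dimensions (space-to-depth), adds a couple of unconditioned flow steps,
    then K FlowStep layers conditioned on the matching 'fea_upX' RRDB feature
    map, and optionally a Split2d layer that factors out part of the channels
    as latent variables. forward(reverse=False) encodes a GT image to latents;
    forward(reverse=True) decodes latents back to a super-resolved image.
    """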

    def __init__(self, image_shape, hidden_channels, scale,
                 rrdb_blocks,
                 actnorm_scale=1.0,
                 flow_permutation='invconv',
                 flow_coupling="affine",
                 LU_decomposed=False, K=16, L=3,
                 norm_opt=None,
                 n_bypass_channels=None):
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.L = L
        self.K = K
        self.scale = scale
        if isinstance(self.K, int):
            self.K = [K] * (self.L + 1)
        H, W, self.C = image_shape
        self.image_shape = image_shape
        self.check_image_shape()

        if scale == 16:
            self.levelToName = {
                0: 'fea_up16',
                1: 'fea_up8',
                2: 'fea_up4',
                3: 'fea_up2',
                4: 'fea_up1',
            }
        elif scale == 8:
            self.levelToName = {
                0: 'fea_up8',
                1: 'fea_up4',
                2: 'fea_up2',
                3: 'fea_up1',
                4: 'fea_up0'
            }
        elif scale == 4:
            self.levelToName = {
                0: 'fea_up4',
                1: 'fea_up2',
                2: 'fea_up1',
                3: 'fea_up0',
                4: 'fea_up-1'
            }

        affineInCh = self.get_affineInCh(rrdb_blocks)

        conditional_channels = {}
        n_rrdb = self.get_n_rrdb_channels(rrdb_blocks)
        conditional_channels[0] = n_rrdb
        for level in range(1, self.L + 1):
            # Level 1 gets conditionals from 2, 3, 4 => L - level
            # Level 2 gets conditionals from 3, 4
            # Level 3 gets conditionals from 4
            # Level 4 gets conditionals from None
            n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
            conditional_channels[level] = n_rrdb + n_bypass

        # Upsampler
        for level in range(1, self.L + 1):
            # 1. Squeeze
            H, W = self.arch_squeeze(H, W)

            # 2. K FlowStep
            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels)
            self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                               flow_permutation,
                               hidden_channels, norm_opt,
                               n_conditional_channels=conditional_channels[level])

            # Split
            self.arch_split(H, W, level, self.L)

        self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
        self.H = H
        self.W = W

    def get_n_rrdb_channels(self, blocks):
        n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
        return n_rrdb

    def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
                      hidden_channels, normOpt, n_conditional_channels=None, condAff=None):
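        """Append K FlowStep layers at the current resolution, conditioned on RRDB features."""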
        if condAff is not None:
            condAff['in_channels_rrdb'] = n_conditional_channels
        for k in range(K):
            position_name = self.get_position_name(H, self.scale)
            if normOpt:
                normOpt['position'] = position_name
            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         acOpt=condAff,
                         position=position_name,
                         LU_decomposed=LU_decomposed, idx=k, normOpt=normOpt))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def arch_split(self, H, W, L, levels, split_flow=True, correct_splits=False, logs_eps=0, consume_ratio=.5,
                   split_conditional=False, cond_channels=None, split_type='Split2d'):
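        """Append a Split2d layer (when this level qualifies) that factors out part of the channels as latents."""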
        correction = 0 if correct_splits else 1
        if split_flow and L < levels - correction:
            position_name = self.get_position_name(H, self.scale)
            position = position_name if split_conditional else None
            cond_channels = 0 if cond_channels is None else cond_channels
            if split_type == 'Split2d':
                split = Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                      cond_channels=cond_channels, consume_ratio=consume_ratio)
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            self.C = split.num_channels_pass

    def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, additionalFlowNoAffine=2):
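        """Append a couple of unconditioned flow steps (invconv permutation, no coupling)."""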
        for _ in range(additionalFlowNoAffine):
            self.layers.append(
                FlowStep(in_channels=self.C,
                         hidden_channels=hidden_channels,
                         actnorm_scale=actnorm_scale,
                         flow_permutation='invconv',
                         flow_coupling='noCoupling',
                         LU_decomposed=LU_decomposed))
            self.output_shapes.append(
                [-1, self.C, H, W])

    def arch_squeeze(self, H, W):
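        """Space-to-depth squeeze: halves H and W and quadruples the channel count."""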
        self.C, H, W = self.C * 4, H // 2, W // 2
        self.layers.append(flow.SqueezeLayer(factor=2))
        self.output_shapes.append([-1, self.C, H, W])
        return H, W

    def get_affineInCh(self, rrdb_blocks):
        affineInCh = (len(rrdb_blocks) + 1) * 64
        return affineInCh

    def check_image_shape(self):
        assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3); "
                                            "self.C == 1 or self.C == 3")

    def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
                y_onehot=None):
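        """Encode (gt -> z) when reverse is False; decode (z -> sr) when reverse is True."""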
        if reverse:
            epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses
            sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
            return sr, logdet
        else:
            assert gt is not None
            assert rrdbResults is not None
            z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
            return z, logdet

    def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
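        """Run the flow forward: GT image -> latent z (or list of epses), accumulating the log-determinant."""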
        fl_fea = gt
        reverse = False
        level_conditionals = {}
        bypasses = {}

        L = self.L
        for level in range(1, L + 1):
            bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear',
                                                              align_corners=False)

        for layer, shape in zip(self.layers, self.output_shapes):
            size = shape[2]
            level = int(np.log(self.image_shape[0] / size) / np.log(2))
            if level > 0 and level not in level_conditionals.keys():
                level_conditionals[level] = rrdbResults[self.levelToName[level]]
            level_conditionals[level] = rrdbResults[self.levelToName[level]]

            if isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
            elif isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse,
                                                      level_conditionals[level], y_onehot=y_onehot)
            else:
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)

        z = fl_fea

        if not isinstance(epses, list):
            return z, logdet
        epses.append(z)
        return epses, logdet

    def forward_preFlow(self, fl_fea, logdet, reverse):
        if hasattr(self, 'preFlow'):
            for l in self.preFlow:
                fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
        return fl_fea, logdet

    def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
        if isinstance(epses, list):
            epses.append(eps)
        return fl_fea, logdet

    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
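        """Run the flow in reverse: latent z / epses (or noise sampled at eps_std) -> super-resolved image."""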
        z = epses.pop() if isinstance(epses, list) else z
        fl_fea = z
        # debug.imwrite("fl_fea", fl_fea)
        bypasses = {}
        level_conditionals = {}
        for level in range(self.L + 1):
            level_conditionals[level] = rrdbResults[self.levelToName[level]]

        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            size = shape[2]
            # Level index measured against the full-resolution image, mirroring encode().
            level = int(np.log(self.image_shape[0] / size) / np.log(2))
            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)

        sr = fl_fea

        assert sr.shape[1] == 3
        return sr, logdet

    def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
        ft = None if layer.position is None else rrdbResults[layer.position]
        fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                               eps=epses.pop() if isinstance(epses, list) else None,
                               eps_std=eps_std, ft=ft, y_onehot=y_onehot)
        return fl_fea, logdet

    def get_position_name(self, H, scale):
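        """Return the 'fea_upX' key of the RRDB feature map that matches the current flow resolution."""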
        downscale_factor = self.image_shape[0] // H
        position_name = 'fea_up{}'.format(scale / downscale_factor)
        return position_name
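

# Minimal usage sketch (not part of the original file; the shapes, RRDB block indices
# and the `rrdb_feats` / `hr_batch` names below are illustrative assumptions only):
# the net encodes an HR batch into latents and decodes latents back into an SR image.
#
#   net = FlowUpsamplerNet(image_shape=(160, 160, 3), hidden_channels=64, scale=4,
#                          rrdb_blocks=[1, 8, 15, 22], K=16, L=3)
#   # rrdb_feats: dict with keys 'fea_up4', 'fea_up2', 'fea_up1', 'fea_up0', 'fea_up-1'
#   z, logdet = net(gt=hr_batch, rrdbResults=rrdb_feats, reverse=False)       # encode
#   sr, logdet = net(rrdbResults=rrdb_feats, z=z, eps_std=0.9, reverse=True)  # decode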