James Betker 8ab595e427 Add FlatProcessorNet
After doing some thinking and reading on the subject, it occurred to me that
I was treating the generator like a discriminator by focusing the network
complexity at the feature levels. It makes far more sense to process each conv
level equally for the generator, hence the FlatProcessorNet in this commit. This
network borrows some of the residual pass-through logic from RRDB which makes
the gradient path exceptionally short for pretty much all model parameters and
can be trained in O1 optimization mode without overflows again.
2020-04-28 11:49:21 -06:00

102 lines
3.2 KiB

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
def initialize_weights(net_l, scale=1):
    """Initialise Conv2d, Linear and BatchNorm2d layers of one or more networks.

    Conv2d and Linear weights receive Kaiming-normal initialisation
    (``a=0``, ``mode='fan_in'``) and are then multiplied by ``scale``; their
    biases, when present, are set to zero. BatchNorm2d affine weights are set
    to 1 and biases to 0. All other module types are left untouched.

    Args:
        net_l: A single ``nn.Module`` or a list/tuple of modules. Every
            sub-module yielded by ``modules()`` is visited.
        scale: Multiplier applied to the freshly initialised Conv2d/Linear
            weights (the residual blocks in this file pass 0.1).
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # With affine=False both weight and bias are None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly constructed blocks into an ``nn.Sequential``.

    Args:
        block: Zero-argument callable (e.g. a module class or a
            ``functools.partial``) that returns a new ``nn.Module`` each time
            it is called.
        n_layers: Number of blocks to create; 0 yields an empty container.

    Returns:
        nn.Sequential: The blocks chained in creation order.
    """
    # Call block() once per layer so every layer owns its own parameters.
    layers = [block() for _ in range(n_layers)]
    return nn.Sequential(*layers)
class ResidualBlock(nn.Module):
    '''Residual block with BN.

    The main branch is Conv -> BN -> ReLU -> Conv -> BN; its result is added
    to the unmodified input (identity skip connection). Both convolutions
    are 3x3, stride 1, padding 1 and keep the channel count at ``nf``.

    Args:
        nf (int): Number of input and output feature channels.
    '''

    def __init__(self, nf=64):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.BN1 = nn.BatchNorm2d(nf)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.BN2 = nn.BatchNorm2d(nf)

        # initialization: conv weights are scaled by 0.1 after Kaiming init
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        '''Return ``x`` plus the output of the Conv-BN-ReLU-Conv-BN branch.'''
        identity = x
        out = F.relu(self.BN1(self.conv1(x)), inplace=True)
        out = self.BN2(self.conv2(out))
        return identity + out
class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN.

    The main branch is Conv -> ReLU -> Conv; its result is added to the
    unmodified input (identity skip connection). Both convolutions are 3x3,
    stride 1, padding 1 and keep the channel count at ``nf``.

    Args:
        nf (int): Number of input and output feature channels.
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization: conv weights are scaled by 0.1 after Kaiming init
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        '''Return ``x`` plus the output of the Conv-ReLU-Conv branch.'''
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Each output pixel (h, w) is sampled from ``x`` at position
    ``(w + flow[..., 0], h + flow[..., 1])``, so an all-zero flow returns
    ``x`` unchanged (up to floating-point rounding).

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacement in pixels; channel 0
            is added to the x (width) coordinate, channel 1 to the y
            (height) coordinate.
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): Forwarded to ``F.grid_sample``. The grid below
            is normalised with ``W - 1`` / ``H - 1``, i.e. -1 and 1 refer to
            the centres of the corner pixels, which is the
            ``align_corners=True`` convention; pass False only to reproduce
            results obtained with ``grid_sample``'s own default.

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3], 'x and flow must have the same H and W'
    _, _, H, W = x.size()
    # Mesh grid of absolute pixel coordinates, created directly on x's device.
    ys = torch.arange(H, device=x.device).view(H, 1).expand(H, W)
    xs = torch.arange(W, device=x.device).view(1, W).expand(H, W)
    grid = torch.stack((xs, ys), dim=2).to(x.dtype)  # (H, W, 2), last dim is (x, y)
    vgrid = grid + flow
    # Scale grid to [-1, 1]; max(..., 1) avoids division by zero for size-1 axes.
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode,
                           align_corners=align_corners)
    return output