forked from mrq/DL-Art-School
biggan arch, initial work (not implemented)
parent 61ed51d9e4 · commit 79593803f2
codes/models/archs/biggan_gen_arch.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
import numpy as np
import math
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

import models.archs.biggan_layers as layers
from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d

# BigGAN-deep: uses a different resblock and pattern


# Architectures for G
# Attention is passed in the format '32_64', meaning an attention block is applied
# at both the 32x32 and 64x64 resolutions. Just '64' will apply it at 64x64 only.

# channel_ratio is the ratio of a block's input channels to its hidden
# (bottleneck) channels.
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
                 upsample=None, channel_ratio=4):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels,
                                     kernel_size=1, padding=0)
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(self.hidden_channels, self.out_channels,
                                     kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x, y)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h, y))
        # Drop channels in x if necessary
        if self.in_channels != self.out_channels:
            x = x[:, :self.out_channels]
        # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h, y)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h, y)))
        return h + x


def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    arch = {}
    arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
                 'upsample': [True] * 6,
                 'resolution': [8, 16, 32, 64, 128, 256],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 9)}}
    arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                 'upsample': [True] * 5,
                 'resolution': [8, 16, 32, 64, 128],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 8)}}
    arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
                'out_channels': [ch * item for item in [16, 8, 4, 2]],
                'upsample': [True] * 4,
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)}}
    arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
                'out_channels': [ch * item for item in [4, 4, 4]],
                'upsample': [True] * 3,
                'resolution': [8, 16, 32],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)}}

    return arch
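
# Added note (not from the upstream BigGAN code), for quick reference: with the
# defaults ch=64, attention='64', G_arch()[128] evaluates to
#   in_channels:  [1024, 1024, 512, 256, 128]
#   out_channels: [1024,  512, 256, 128,  64]
#   upsample:     [True] * 5
#   resolution:   [8, 16, 32, 64, 128]
#   attention:    {8: False, 16: False, 32: False, 64: True, 128: False}
# i.e. one upsampling stage per entry, with a self-attention block only at 64x64.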


class Generator(nn.Module):
    def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Number of resblocks per stage
        self.G_depth = G_depth
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # for some reason applying SN to G's embedding seems to randomly cripple G.
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=(self.shared_dim + self.dim_z if self.G_shared
                                                      else self.n_classes),
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        # First linear layer
        self.linear = self.which_linear(self.dim_z + self.shared_dim,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                                    out_channels=self.arch['in_channels'][index] if g_index == 0 else
                                    self.arch['out_channels'][index],
                                    which_conv=self.which_conv,
                                    which_bn=self.which_bn,
                                    activation=self.activation,
                                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                                              if self.arch['upsample'][index] and g_index == (
                                                  self.G_depth - 1) else None))]
                            for g_index in range(self.G_depth)]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print("Param count for G's initialized parameters: %d" % self.param_count)

    # Note on this forward function: we pass in a y vector which has
    # already been passed through G.shared to enable easy class-wise
    # interpolation later. If we passed in the one-hot and then ran it through
    # G.shared in this forward function, it would be harder to handle.
    # NOTE: The z vs y dichotomy here is for compatibility with not-y
    def forward(self, z, y):
        # If hierarchical, concatenate zs and ys
        if self.hier:
            z = torch.cat([y, z], 1)
            y = z
        # First linear layer
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, y)

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))
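
A minimal smoke-test sketch for orientation (added here, not part of the commit). It assumes the interpreter is started from the codes directory so the models.archs.* imports resolve, and that the comm.py helper from Synchronized-BatchNorm-PyTorch, which this commit does not include, has been copied next to biggan_sync_batchnorm.py. Note that the first linear layer expects dim_z + shared_dim inputs, so hier=True is needed for the shapes to line up.

import torch
from models.archs.biggan_gen_arch import Generator

G = Generator(G_ch=64, G_depth=2, dim_z=128, resolution=128, n_classes=1000,
              G_shared=True, shared_dim=128, hier=True,
              skip_init=True, no_optim=True)
z = torch.randn(4, 128)                         # latent vectors
y = G.shared(torch.randint(0, 1000, (4,)))      # class embeddings, as forward() expects
img = G(z, y)                                   # -> (4, 3, 128, 128), tanh output in [-1, 1]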

codes/models/archs/biggan_layers.py (new file, 464 lines)
@@ -0,0 +1,464 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d


# Projection of x onto y
def proj(x, y):
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())


# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    for y in ys:
        x = x - proj(x, y)
    return x


# Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # Run one step of the power iteration
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            # Add to the list
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            # Add to the list
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value and add it to the list
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
        # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
    return svs, us, vs
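
# Added note (not from the upstream BigGAN code): each pass above is one power-iteration
# step, v <- normalize(u @ W) and u <- normalize(v @ W.t()), Gram-Schmidt-orthogonalized
# against the previously estimated vectors so the i-th (u, v) pair tracks the i-th
# singular pair; svs[i] = v_i @ W.t() @ u_i.t() is the corresponding singular-value
# estimate, and svs[0] is what the SN layers below divide their weight by.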


# Convenience passthrough function
class identity(nn.Module):
    def forward(self, input):
        return input


# Spectral normalization base class
class SN(object):
    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    # Singular vectors (u side)
    @property
    def u(self):
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    # Singular values;
    # note that these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    # Compute the spectrally-normalized weight
    def W_(self):
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        # Update the svs
        if self.training:
            with torch.no_grad():  # Make sure to do this in a no_grad() context or you'll get memory leaks!
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(x, self.W_(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
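
# Added note (not from the upstream BigGAN code): SNConv2d / SNLinear are drop-in
# replacements for nn.Conv2d / nn.Linear whose effective weight is W / sigma_1(W),
# with sigma_1 re-estimated by num_itrs power-iteration steps on every forward pass
# (the u buffers are only updated in training mode). A hypothetical use:
#   conv = SNConv2d(64, 128, kernel_size=3, padding=1, num_svs=1, num_itrs=1)
#   out = conv(torch.randn(2, 64, 32, 32))   # weight spectrally normalized on the fly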


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience's sake
class SNEmbedding(nn.Embedding, SN):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                              max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())


# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable gain parameter
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
        return self.gamma * o + x
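
# Added note (not from the upstream BigGAN code): shape walk-through for an input
# x of shape (B, C, H, W) with C = self.ch:
#   theta: (B, C/8, H*W)      phi: (B, C/8, H*W/4)      g: (B, C/2, H*W/4)
#   beta = softmax(theta^T phi): (B, H*W, H*W/4)
#   o = conv_o(g beta^T reshaped to (B, C/2, H, W)): (B, C, H, W)
# gamma starts at 0, so the block initially acts as an identity mapping.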


# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    # Apply scale and shift--if gain and bias are provided, fuse them here
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = (m2 - m ** 2)
    # Cast back to float16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)


# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # If not in training mode, use the stored statistics
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)


# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)


# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn'):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm or cross-replica batchnorm, fuse gain and bias into it
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        # Otherwise normalize first, then apply gain and bias
        else:
            if self.norm_style == 'bn':
                out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                                   self.training, 0.1, self.eps)
            elif self.norm_style == 'in':
                out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                      self.training, 0.1, self.eps)
            elif self.norm_style == 'gn':
                out = groupnorm(x, self.norm_style)
            elif self.norm_style == 'nonorm':
                out = x
            return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
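
# Added note (not from the upstream BigGAN code): ccbn is called as cc(x, y), where x is
# the (B, output_size, H, W) feature map and y is the (B, input_size) conditioning vector
# (in the generator above, the shared class embedding, concatenated with the latent when
# hier is enabled). A hypothetical standalone use:
#   cc = ccbn(output_size=256, input_size=128, which_linear=nn.Linear)
#   out = cc(torch.randn(4, 256, 8, 8), torch.randn(4, 128))   # per-sample gain/bias from y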


# Normal, non-class-conditional BN
class bn(nn.Module):
    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        else:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)


# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x


# Residual block for the discriminator
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = True if (in_channels != out_channels) or downsample else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # h = self.activation(x) # NOT TODAY SATAN
            # Andy's note: This line *must* be an out-of-place ReLU or it
            # will negatively affect the shortcut connection.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)


# dogball

codes/models/archs/biggan_sync_batchnorm.py (new file, 351 lines)
@@ -0,0 +1,351 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])

class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                out = out + gain
            if bias is not None:
                out = out + bias
            return out

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        # print(input_shape)
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)
        # Reduce-and-broadcast the statistics.
        # print('it begins')
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
        # if self._parallel_id == 0:
        #     # print('here')
        #     sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        # else:
        #     # print('there')
        #     sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # print('how2')
        # num = sum_size
        # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' % (float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
        # Fix the graph
        # sum = (sum.detach() - input_sum.detach()) + input_sum
        # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum

        # mean = sum / num
        # var = ssum / num - mean ** 2
        # # var = (ssum - mean * sum) / num
        # inv_std = torch.rsqrt(var + self.eps)

        # Compute the output.
        if gain is not None:
            # print('gaining')
            # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
            # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
            # output = input * scale - shift
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
        # print('a')
        # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
        # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
        # print('b')
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
            # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
        # return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
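
A minimal usage sketch for orientation (added here, not part of the commit): on a single device the layer falls back to plain F.batch_norm, so it can be exercised standalone; the data-parallel path additionally needs the SyncMaster helper from the upstream comm.py, which this commit does not include, so the import below assumes that file has been copied alongside.

import torch
from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d

sbn = SynchronizedBatchNorm2d(64, affine=True)
out = sbn(torch.randn(8, 64, 16, 16))   # single-device path: plain F.batch_norm under the hood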