Get rid of biggan
Not really sure it's a great fit for what is being done here.
parent 0a714e8451
commit ddfd7f67a0
@@ -1,209 +0,0 @@
# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

import models.archs.biggan_layers as layers

# BigGAN-deep: uses a different resblock and pattern


# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.


# Channel ratio is the ratio of a block's input channels to its hidden
# (bottleneck) channels.
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
                 upsample=None, channel_ratio=4):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels,
                                     kernel_size=1, padding=0)
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(self.hidden_channels, self.out_channels,
                                     kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h))
        # Drop channels in x if necessary
        if self.in_channels != self.out_channels:
            x = x[:, :self.out_channels]
        # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h)))
        return h + x


def G_arch(ch=64, attention='64', base_width=64):
    arch = {}
    arch[128] = {'in_channels': [ch * item for item in [2, 2, 1, 1]],
                 'out_channels': [ch * item for item in [2, 1, 1, 1]],
                 'upsample': [False, True, False, False],
                 'resolution': [base_width, base_width, base_width * 2, base_width * 2],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 8)}}

    return arch
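

# Illustrative sketch, not part of the original file: what the arch dict above looks
# like for the default settings. The attention spec '32_64' would instead mark both
# the 32x32 and 64x64 entries True.
def _example_g_arch():
    arch = G_arch(ch=64, attention='64')[128]
    # arch['in_channels']  == [128, 128, 64, 64]
    # arch['out_channels'] == [128, 64, 64, 64]
    # arch['resolution']   == [64, 64, 128, 128]
    # arch['attention']    == {8: False, 16: False, 32: False, 64: True, 128: False}
    return arch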


class Generator(nn.Module):
    def __init__(self, G_ch=64, G_depth=2, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64',
                 num_G_SVs=1, num_G_SV_itrs=1, hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 BN_eps=1e-5, SN_eps=1e-12,
                 G_init='ortho', skip_init=False,
                 G_param='SN', norm_style='bn'):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Number of resblocks per stage
        self.G_depth = G_depth
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)

        self.which_bn = functools.partial(layers.bn,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Prepare model
        # First conv layer to project into feature-space
        self.initial_conv = nn.Sequential(self.which_conv(3, self.arch['in_channels'][0]),
                                          layers.bn(self.arch['in_channels'][0],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation
                                          )

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                                    out_channels=self.arch['in_channels'][index] if g_index == 0 else
                                    self.arch['out_channels'][index],
                                    which_conv=self.which_conv,
                                    which_bn=self.which_bn,
                                    activation=self.activation,
                                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                                              if self.arch['upsample'][index] and g_index == (
                                                  self.G_depth - 1) else None))]
                            for g_index in range(self.G_depth)]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print('Param count for G\'s initialized parameters: %d' % self.param_count)

    def forward(self, z):
        # First conv layer to convert into correct filter-space.
        h = self.initial_conv(z)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h)

        # Apply batchnorm-relu-conv-tanh at output
        return (torch.tanh(self.output_layer(h)), )


def biggan_medium(num_filters):
    return Generator(num_filters)
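

# Illustrative usage sketch, not part of the original file. With the default 128 arch
# the generator consumes a 3-channel image, applies a single 2x upsample, and returns
# a 1-tuple holding a tanh-squashed 3-channel image.
def _example_generator_shapes():
    g = biggan_medium(64).eval()
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        out, = g(x)
    return out.shape  # expected: torch.Size([1, 3, 64, 64])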
@@ -1,459 +0,0 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P


# Projection of x onto y
def proj(x, y):
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())


# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    for y in ys:
        x = x - proj(x, y)
    return x


# Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # Run one step of the power iteration
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            # Add to the list
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            # Add to the list
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value and add it to the list
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
        # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
    return svs, us, vs
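

# Illustrative check, not part of the original file. After enough iterations the first
# estimate in svs should approach the true largest singular value of W (compared here
# against torch.linalg.svdvals, available in recent PyTorch releases).
def _example_power_iteration_check(iters=50):
    W = torch.randn(64, 128)
    u_ = [torch.randn(1, 64)]  # one left singular-vector estimate, shape (1, W.size(0))
    for _ in range(iters):
        svs, _, _ = power_iteration(W, u_, update=True)
    return float(svs[0]), float(torch.linalg.svdvals(W)[0])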


# Convenience passthrough function
class identity(nn.Module):
    def forward(self, input):
        return input


# Spectral normalization base class
class SN(object):
    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    # Singular vectors (u side)
    @property
    def u(self):
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    # Singular values;
    # note that these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    # Compute the spectrally-normalized weight
    def W_(self):
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        # Update the svs
        if self.training:
            with torch.no_grad():  # Make sure to do this in a no_grad() context or you'll get memory leaks!
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(x, self.W_(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
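

# Illustrative check, not part of the original file. Each forward call in training mode
# runs one power iteration and rescales the weight by the estimated top singular value,
# so the spectral norm of W_() should approach 1 after a handful of calls.
def _example_spectral_norm_check(steps=20):
    layer = SNLinear(32, 16)
    x = torch.randn(4, 32)
    for _ in range(steps):
        layer(x)
    return float(torch.linalg.svdvals(layer.W_())[0])  # expected: close to 1.0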


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                              max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())


# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable gain parameter
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
        return self.gamma * o + x
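

# Illustrative shape check, not part of the original file. The block is residual and
# gamma starts at zero, so the output has the input's shape and is initially equal to x.
def _example_attention_shapes():
    attn = Attention(64)
    x = torch.randn(2, 64, 32, 32)
    return attn(x).shape  # expected: torch.Size([2, 64, 32, 32])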


# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    # Apply scale and shift--if gain and bias are provided, fuse them here
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias  # The unfused way.


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = (m2 - m ** 2)
    # Cast back to float16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)
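

# Illustrative check, not part of the original file. manual_bn/fused_bn together should
# match the unfused formula (x - mean) / sqrt(var + eps) * gain + bias, with statistics
# taken over the (N, H, W) axes.
def _example_manual_bn_check():
    x = torch.randn(8, 16, 4, 4)
    gain = torch.full((1, 16, 1, 1), 2.0)
    bias = torch.full((1, 16, 1, 1), 0.5)
    out = manual_bn(x, gain, bias)
    m = x.mean([0, 2, 3], keepdim=True)
    v = x.var([0, 2, 3], unbiased=False, keepdim=True)
    ref = (x - m) / torch.sqrt(v + 1e-5) * gain + bias
    return torch.allclose(out, ref, atol=1e-4)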


# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Register buffers
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # If not in training mode, use the stored statistics
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)


# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)
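

# Illustrative examples, not part of the original file: how the norm_style string is
# parsed. 'ch_16' asks for roughly 16 channels per group, 'grp_8' for exactly 8 groups,
# and anything else falls back to 16 groups.
def _example_groupnorm_styles():
    x = torch.randn(2, 64, 8, 8)
    a = groupnorm(x, 'ch_16')   # 64 // 16 = 4 groups
    b = groupnorm(x, 'grp_8')   # 8 groups
    c = groupnorm(x, 'other')   # default: 16 groups
    return a.shape, b.shape, c.shape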


# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn'):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        # else:
        else:
            if self.norm_style == 'bn':
                out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                                   self.training, 0.1, self.eps)
            elif self.norm_style == 'in':
                out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                      self.training, 0.1, self.eps)
            elif self.norm_style == 'gn':
                out = groupnorm(x, self.norm_style)
            elif self.norm_style == 'nonorm':
                out = x
            return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
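

# Illustrative usage, not part of the original file. The conditioning vector y (here a
# random 128-dim embedding) is mapped by the two linear layers to per-sample,
# per-channel gains and biases.
def _example_ccbn_usage():
    cbn = ccbn(output_size=64, input_size=128, which_linear=nn.Linear, norm_style='bn')
    x = torch.randn(4, 64, 8, 8)
    y = torch.randn(4, 128)
    return cbn(x, y).shape  # expected: torch.Size([4, 64, 8, 8])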


# Normal, non-class-conditional BN
class bn(nn.Module):
    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style=None):
        super(bn, self).__init__()
        self.output_size = output_size
        # Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        else:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)


# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
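

# Illustrative usage, not part of the original file. As the note above this class says,
# kernel size/padding and the normalization flavour are chosen by the caller and baked
# into which_conv / which_bn before they reach the block (ccbn would be the
# class-conditional alternative to the plain bn used here).
def _example_gblock_usage():
    def which_conv(in_ch, out_ch, kernel_size=3, padding=1):
        return SNConv2d(in_ch, out_ch, kernel_size, padding=padding)

    block = GBlock(64, 32, which_conv=which_conv, which_bn=bn,
                   activation=nn.ReLU(inplace=False),
                   upsample=nn.Upsample(scale_factor=2))
    x, y = torch.randn(2, 64, 16, 16), torch.randn(2, 128)
    return block(x, y).shape  # expected: torch.Size([2, 32, 32, 32])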


# Residual block for the discriminator
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = True if (in_channels != out_channels) or downsample else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # h = self.activation(x)  # NOT TODAY SATAN
            # Andy's note: This line *must* be an out-of-place ReLU or it
            # will negatively affect the shortcut connection.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)


# dogball
@@ -1,489 +0,0 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])

class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                out = out + gain
            if bias is not None:
                out = out + bias
            return out

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        # print(input_shape)
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)
        # Reduce-and-broadcast the statistics.
        # print('it begins')
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
        # if self._parallel_id == 0:
        #     # print('here')
        #     sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        # else:
        #     # print('there')
        #     sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # print('how2')
        # num = sum_size
        # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' % (float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
        # Fix the graph
        # sum = (sum.detach() - input_sum.detach()) + input_sum
        # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum

        # mean = sum / num
        # var = ssum / num - mean ** 2
        # # var = (ssum - mean * sum) / num
        # inv_std = torch.rsqrt(var + self.eps)

        # Compute the output.
        if gain is not None:
            # print('gaining')
            # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
            # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
            # output = input * scale - shift
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
        # print('a')
        # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
        # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
        # print('b')
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
            # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
        # return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
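

# Illustrative check, not part of the original file. On a single device (no DataParallel
# replication), forward() falls back to F.batch_norm, so the module should match the
# built-in torch.nn.BatchNorm2d given the same parameters and input.
def _example_single_device_equivalence():
    sync_bn = SynchronizedBatchNorm2d(8)
    ref_bn = torch.nn.BatchNorm2d(8)
    ref_bn.load_state_dict(sync_bn.state_dict())
    x = torch.randn(4, 8, 5, 5)
    return torch.allclose(sync_bn(x), ref_bn(x), atol=1e-6)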


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


# From comm.py
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as data parallel will trigger a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected
      and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
                message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
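

# Illustrative round trip, not part of the original file: one master thread and one
# slave thread exchange messages through SyncMaster. The callback receives
# [(0, master_msg), (1, slave_msg)] and decides what each side gets back.
def _example_sync_master_round():
    master = SyncMaster(lambda msgs: [(i, m * 10) for i, m in msgs])
    pipe = master.register_slave(1)
    result = {}

    def slave_job():
        result['slave'] = pipe.run_slave(2)

    slave = threading.Thread(target=slave_job)
    slave.start()
    result['master'] = master.run_master(1)
    slave.join()
    return result  # expected: {'slave': 20, 'master': 10}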