From ddfd7f67a0f424f0720b6eae8864c9d1ecbe5671 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 16 Jun 2020 11:21:44 -0600
Subject: [PATCH] Get rid of biggan

Not really sure it's a great fit for what is being done here.
---
 codes/models/archs/biggan_gen_arch.py       | 209 ---------
 codes/models/archs/biggan_layers.py         | 459 ------------------
 codes/models/archs/biggan_sync_batchnorm.py | 489 --------------------
 3 files changed, 1157 deletions(-)
 delete mode 100644 codes/models/archs/biggan_gen_arch.py
 delete mode 100644 codes/models/archs/biggan_layers.py
 delete mode 100644 codes/models/archs/biggan_sync_batchnorm.py

diff --git a/codes/models/archs/biggan_gen_arch.py b/codes/models/archs/biggan_gen_arch.py
deleted file mode 100644
index 9ba4f704..00000000
--- a/codes/models/archs/biggan_gen_arch.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
-import functools
-
-import torch
-import torch.nn as nn
-from torch.nn import init
-import torch.nn.functional as F
-
-import models.archs.biggan_layers as layers
-
-# BigGAN-deep: uses a different resblock and pattern
-
-# Architectures for G
-# Attention is passed in in the format '32_64' to mean applying an attention
-# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
-
-# Channel ratio is the ratio of
-class GBlock(nn.Module):
-    def __init__(self, in_channels, out_channels,
-                 which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
-                 upsample=None, channel_ratio=4):
-        super(GBlock, self).__init__()
-
-        self.in_channels, self.out_channels = in_channels, out_channels
-        self.hidden_channels = self.in_channels // channel_ratio
-        self.which_conv, self.which_bn = which_conv, which_bn
-        self.activation = activation
-        # Conv layers
-        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels,
-                                     kernel_size=1, padding=0)
-        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
-        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
-        self.conv4 = self.which_conv(self.hidden_channels, self.out_channels,
-                                     kernel_size=1, padding=0)
-        # Batchnorm layers
-        self.bn1 = self.which_bn(self.in_channels)
-        self.bn2 = self.which_bn(self.hidden_channels)
-        self.bn3 = self.which_bn(self.hidden_channels)
-        self.bn4 = self.which_bn(self.hidden_channels)
-        # upsample layers
-        self.upsample = upsample
-
-    def forward(self, x):
-        # Project down to channel ratio
-        h = self.conv1(self.activation(self.bn1(x)))
-        # Apply next BN-ReLU
-        h = self.activation(self.bn2(h))
-        # Drop channels in x if necessary
-        if self.in_channels != self.out_channels:
-            x = x[:, :self.out_channels]
-        # Upsample both h and x at this point
-        if self.upsample:
-            h = self.upsample(h)
-            x = self.upsample(x)
-        # 3x3 convs
-        h = self.conv2(h)
-        h = self.conv3(self.activation(self.bn3(h)))
-        # Final 1x1 conv
-        h = self.conv4(self.activation(self.bn4(h)))
-        return h + x
-
-
-def G_arch(ch=64, attention='64', base_width=64):
-    arch = {}
-    arch[128] = {'in_channels': [ch * item for item in [2, 2, 1, 1]],
-                 'out_channels': [ch * item for item in [2, 1, 1, 1]],
-                 'upsample': [False, True, False, False],
-                 'resolution': [base_width, base_width, base_width*2, base_width*2],
-                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
-                               for i in range(3, 8)}}
-
-    return arch
-
-
-class Generator(nn.Module):
-    def __init__(self, G_ch=64, G_depth=2, bottom_width=4, resolution=128,
-                 G_kernel_size=3, G_attn='64',
-                 num_G_SVs=1, num_G_SV_itrs=1, hier=False,
cross_replica=False, mybn=False, - G_activation=nn.ReLU(inplace=False), - BN_eps=1e-5, SN_eps=1e-12, - G_init='ortho', skip_init=False, - G_param='SN', norm_style='bn'): - super(Generator, self).__init__() - # Channel width multiplier - self.ch = G_ch - # Number of resblocks per stage - self.G_depth = G_depth - # The initial spatial dimensions - self.bottom_width = bottom_width - # Resolution of the output - self.resolution = resolution - # Kernel size? - self.kernel_size = G_kernel_size - # Attention? - self.attention = G_attn - # Hierarchical latent space? - self.hier = hier - # Cross replica batchnorm? - self.cross_replica = cross_replica - # Use my batchnorm? - self.mybn = mybn - # nonlinearity for residual blocks - self.activation = G_activation - # Initialization style - self.init = G_init - # Parameterization style - self.G_param = G_param - # Normalization style - self.norm_style = norm_style - # Epsilon for BatchNorm? - self.BN_eps = BN_eps - # Epsilon for Spectral Norm? - self.SN_eps = SN_eps - # Architecture dict - self.arch = G_arch(self.ch, self.attention)[resolution] - - # Which convs, batchnorms, and linear layers to use - if self.G_param == 'SN': - self.which_conv = functools.partial(layers.SNConv2d, - kernel_size=3, padding=1, - num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, - eps=self.SN_eps) - else: - self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) - - self.which_bn = functools.partial(layers.bn, - cross_replica=self.cross_replica, - mybn=self.mybn, - norm_style=self.norm_style, - eps=self.BN_eps) - - # Prepare model - # First conv layer to project into feature-space - self.initial_conv = nn.Sequential(self.which_conv(3, self.arch['in_channels'][0]), - layers.bn(self.arch['in_channels'][0], - cross_replica=self.cross_replica, - mybn=self.mybn), - self.activation - ) - - # self.blocks is a doubly-nested list of modules, the outer loop intended - # to be over blocks at a given resolution (resblocks and/or self-attention) - # while the inner loop is over a given block - self.blocks = [] - for index in range(len(self.arch['out_channels'])): - self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index], - out_channels=self.arch['in_channels'][index] if g_index == 0 else - self.arch['out_channels'][index], - which_conv=self.which_conv, - which_bn=self.which_bn, - activation=self.activation, - upsample=(functools.partial(F.interpolate, scale_factor=2) - if self.arch['upsample'][index] and g_index == ( - self.G_depth - 1) else None))] - for g_index in range(self.G_depth)] - - # If attention on this block, attach it to the end - if self.arch['attention'][self.arch['resolution'][index]]: - print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) - self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] - - # Turn self.blocks into a ModuleList so that it's all properly registered. - self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) - - # output layer: batchnorm-relu-conv. - # Consider using a non-spectral conv here - self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], - cross_replica=self.cross_replica, - mybn=self.mybn), - self.activation, - self.which_conv(self.arch['out_channels'][-1], 3)) - - # Initialize weights. Optionally skip init for testing. 
- if not skip_init: - self.init_weights() - - # Initialize - def init_weights(self): - self.param_count = 0 - for module in self.modules(): - if (isinstance(module, nn.Conv2d) - or isinstance(module, nn.Linear) - or isinstance(module, nn.Embedding)): - if self.init == 'ortho': - init.orthogonal_(module.weight) - elif self.init == 'N02': - init.normal_(module.weight, 0, 0.02) - elif self.init in ['glorot', 'xavier']: - init.xavier_uniform_(module.weight) - else: - print('Init style not recognized...') - self.param_count += sum([p.data.nelement() for p in module.parameters()]) - print('Param count for G''s initialized parameters: %d' % self.param_count) - - def forward(self, z): - # First conv layer to convert into correct filter-space. - h = self.initial_conv(z) - # Loop over blocks - for index, blocklist in enumerate(self.blocks): - # Second inner loop in case block has multiple layers - for block in blocklist: - h = block(h) - - # Apply batchnorm-relu-conv-tanh at output - return (torch.tanh(self.output_layer(h)), ) - -def biggan_medium(num_filters): - return Generator(num_filters) \ No newline at end of file diff --git a/codes/models/archs/biggan_layers.py b/codes/models/archs/biggan_layers.py deleted file mode 100644 index 58e24fc4..00000000 --- a/codes/models/archs/biggan_layers.py +++ /dev/null @@ -1,459 +0,0 @@ -''' Layers - This file contains various layers for the BigGAN models. -''' -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Parameter as P - - -# Projection of x onto y -def proj(x, y): - return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) - - -# Orthogonalize x wrt list of vectors ys -def gram_schmidt(x, ys): - for y in ys: - x = x - proj(x, y) - return x - - -# Apply num_itrs steps of the power method to estimate top N singular values. -def power_iteration(W, u_, update=True, eps=1e-12): - # Lists holding singular vectors and values - us, vs, svs = [], [], [] - for i, u in enumerate(u_): - # Run one step of the power iteration - with torch.no_grad(): - v = torch.matmul(u, W) - # Run Gram-Schmidt to subtract components of all other singular vectors - v = F.normalize(gram_schmidt(v, vs), eps=eps) - # Add to the list - vs += [v] - # Update the other singular vector - u = torch.matmul(v, W.t()) - # Run Gram-Schmidt to subtract components of all other singular vectors - u = F.normalize(gram_schmidt(u, us), eps=eps) - # Add to the list - us += [u] - if update: - u_[i][:] = u - # Compute this singular value and add it to the list - svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] - # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] - return svs, us, vs - - -# Convenience passthrough function -class identity(nn.Module): - def forward(self, input): - return input - - -# Spectral normalization base class -class SN(object): - def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): - # Number of power iterations per step - self.num_itrs = num_itrs - # Number of singular values - self.num_svs = num_svs - # Transposed? 
- self.transpose = transpose - # Epsilon value for avoiding divide-by-0 - self.eps = eps - # Register a singular vector for each sv - for i in range(self.num_svs): - self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) - self.register_buffer('sv%d' % i, torch.ones(1)) - - # Singular vectors (u side) - @property - def u(self): - return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] - - # Singular values; - # note that these buffers are just for logging and are not used in training. - @property - def sv(self): - return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] - - # Compute the spectrally-normalized weight - def W_(self): - W_mat = self.weight.view(self.weight.size(0), -1) - if self.transpose: - W_mat = W_mat.t() - # Apply num_itrs power iterations - for _ in range(self.num_itrs): - svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) - # Update the svs - if self.training: - with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! - for i, sv in enumerate(svs): - self.sv[i][:] = sv - return self.weight / svs[0] - - -# 2D Conv layer with spectral norm -class SNConv2d(nn.Conv2d, SN): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True, - num_svs=1, num_itrs=1, eps=1e-12): - nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, groups, bias) - SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) - - def forward(self, x): - return F.conv2d(x, self.W_(), self.bias, self.stride, - self.padding, self.dilation, self.groups) - - -# Linear layer with spectral norm -class SNLinear(nn.Linear, SN): - def __init__(self, in_features, out_features, bias=True, - num_svs=1, num_itrs=1, eps=1e-12): - nn.Linear.__init__(self, in_features, out_features, bias) - SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) - - def forward(self, x): - return F.linear(x, self.W_(), self.bias) - - -# Embedding layer with spectral norm -# We use num_embeddings as the dim instead of embedding_dim here -# for convenience sake -class SNEmbedding(nn.Embedding, SN): - def __init__(self, num_embeddings, embedding_dim, padding_idx=None, - max_norm=None, norm_type=2, scale_grad_by_freq=False, - sparse=False, _weight=None, - num_svs=1, num_itrs=1, eps=1e-12): - nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, - max_norm, norm_type, scale_grad_by_freq, - sparse, _weight) - SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) - - def forward(self, x): - return F.embedding(x, self.W_()) - - -# A non-local block as used in SA-GAN -# Note that the implementation as described in the paper is largely incorrect; -# refer to the released code for the actual implementation. 
-class Attention(nn.Module): - def __init__(self, ch, which_conv=SNConv2d, name='attention'): - super(Attention, self).__init__() - # Channel multiplier - self.ch = ch - self.which_conv = which_conv - self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) - self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) - self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) - self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) - # Learnable gain parameter - self.gamma = P(torch.tensor(0.), requires_grad=True) - - def forward(self, x, y=None): - # Apply convs - theta = self.theta(x) - phi = F.max_pool2d(self.phi(x), [2, 2]) - g = F.max_pool2d(self.g(x), [2, 2]) - # Perform reshapes - theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3]) - phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4) - g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4) - # Matmul and softmax to get attention maps - beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) - # Attention map times g path - o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) - return self.gamma * o + x - - -# Fused batchnorm op -def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): - # Apply scale and shift--if gain and bias are provided, fuse them here - # Prepare scale - scale = torch.rsqrt(var + eps) - # If a gain is provided, use it - if gain is not None: - scale = scale * gain - # Prepare shift - shift = mean * scale - # If bias is provided, use it - if bias is not None: - shift = shift - bias - return x * scale - shift - # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. - - -# Manual BN -# Calculate means and variances using mean-of-squares minus mean-squared -def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): - # Cast x to float32 if necessary - float_x = x.float() - # Calculate expected value of x (m) and expected value of x**2 (m2) - # Mean of x - m = torch.mean(float_x, [0, 2, 3], keepdim=True) - # Mean of x squared - m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) - # Calculate variance as mean of squared minus mean squared. 
- var = (m2 - m ** 2) - # Cast back to float 16 if necessary - var = var.type(x.type()) - m = m.type(x.type()) - # Return mean and variance for updating stored mean/var if requested - if return_mean_var: - return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() - else: - return fused_bn(x, m, var, gain, bias, eps) - - -# My batchnorm, supports standing stats -class myBN(nn.Module): - def __init__(self, num_channels, eps=1e-5, momentum=0.1): - super(myBN, self).__init__() - # momentum for updating running stats - self.momentum = momentum - # epsilon to avoid dividing by 0 - self.eps = eps - # Momentum - self.momentum = momentum - # Register buffers - self.register_buffer('stored_mean', torch.zeros(num_channels)) - self.register_buffer('stored_var', torch.ones(num_channels)) - self.register_buffer('accumulation_counter', torch.zeros(1)) - # Accumulate running means and vars - self.accumulate_standing = False - - # reset standing stats - def reset_stats(self): - self.stored_mean[:] = 0 - self.stored_var[:] = 0 - self.accumulation_counter[:] = 0 - - def forward(self, x, gain, bias): - if self.training: - out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) - # If accumulating standing stats, increment them - if self.accumulate_standing: - self.stored_mean[:] = self.stored_mean + mean.data - self.stored_var[:] = self.stored_var + var.data - self.accumulation_counter += 1.0 - # If not accumulating standing stats, take running averages - else: - self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum - self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum - return out - # If not in training mode, use the stored statistics - else: - mean = self.stored_mean.view(1, -1, 1, 1) - var = self.stored_var.view(1, -1, 1, 1) - # If using standing stats, divide them by the accumulation counter - if self.accumulate_standing: - mean = mean / self.accumulation_counter - var = var / self.accumulation_counter - return fused_bn(x, mean, var, gain, bias, self.eps) - - -# Simple function to handle groupnorm norm stylization -def groupnorm(x, norm_style): - # If number of channels specified in norm_style: - if 'ch' in norm_style: - ch = int(norm_style.split('_')[-1]) - groups = max(int(x.shape[1]) // ch, 1) - # If number of groups specified in norm style - elif 'grp' in norm_style: - groups = int(norm_style.split('_')[-1]) - # If neither, default to groups = 16 - else: - groups = 16 - return F.group_norm(x, groups) - - -# Class-conditional bn -# output size is the number of channels, input size is for the linear layers -# Andy's Note: this class feels messy but I'm not really sure how to clean it up -# Suggestions welcome! (By which I mean, refactor this and make a pull request -# if you want to make this more readable/usable). -class ccbn(nn.Module): - def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, - cross_replica=False, mybn=False, norm_style='bn', ): - super(ccbn, self).__init__() - self.output_size, self.input_size = output_size, input_size - # Prepare gain and bias layers - self.gain = which_linear(input_size, output_size) - self.bias = which_linear(input_size, output_size) - # epsilon to avoid dividing by 0 - self.eps = eps - # Momentum - self.momentum = momentum - # Use cross-replica batchnorm? - self.cross_replica = cross_replica - # Use my batchnorm? - self.mybn = mybn - # Norm style? 
- self.norm_style = norm_style - - if self.cross_replica: - self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) - elif self.mybn: - self.bn = myBN(output_size, self.eps, self.momentum) - elif self.norm_style in ['bn', 'in']: - self.register_buffer('stored_mean', torch.zeros(output_size)) - self.register_buffer('stored_var', torch.ones(output_size)) - - def forward(self, x, y): - # Calculate class-conditional gains and biases - gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) - bias = self.bias(y).view(y.size(0), -1, 1, 1) - # If using my batchnorm - if self.mybn or self.cross_replica: - return self.bn(x, gain=gain, bias=bias) - # else: - else: - if self.norm_style == 'bn': - out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, - self.training, 0.1, self.eps) - elif self.norm_style == 'in': - out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, - self.training, 0.1, self.eps) - elif self.norm_style == 'gn': - out = groupnorm(x, self.normstyle) - elif self.norm_style == 'nonorm': - out = x - return out * gain + bias - - def extra_repr(self): - s = 'out: {output_size}, in: {input_size},' - s += ' cross_replica={cross_replica}' - return s.format(**self.__dict__) - - -# Normal, non-class-conditional BN -class bn(nn.Module): - def __init__(self, output_size, eps=1e-5, momentum=0.1, - cross_replica=False, mybn=False, norm_style=None): - super(bn, self).__init__() - self.output_size = output_size - # Prepare gain and bias layers - self.gain = P(torch.ones(output_size), requires_grad=True) - self.bias = P(torch.zeros(output_size), requires_grad=True) - # epsilon to avoid dividing by 0 - self.eps = eps - # Momentum - self.momentum = momentum - # Use cross-replica batchnorm? - self.cross_replica = cross_replica - # Use my batchnorm? - self.mybn = mybn - - if self.cross_replica: - self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) - elif mybn: - self.bn = myBN(output_size, self.eps, self.momentum) - # Register buffers if neither of the above - else: - self.register_buffer('stored_mean', torch.zeros(output_size)) - self.register_buffer('stored_var', torch.ones(output_size)) - - def forward(self, x, y=None): - if self.cross_replica or self.mybn: - gain = self.gain.view(1, -1, 1, 1) - bias = self.bias.view(1, -1, 1, 1) - return self.bn(x, gain=gain, bias=bias) - else: - return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, - self.bias, self.training, self.momentum, self.eps) - - -# Generator blocks -# Note that this class assumes the kernel size and padding (and any other -# settings) have been selected in the main generator module and passed in -# through the which_conv arg. 
Similar rules apply with which_bn (the input -# size [which is actually the number of channels of the conditional info] must -# be preselected) -class GBlock(nn.Module): - def __init__(self, in_channels, out_channels, - which_conv=nn.Conv2d, which_bn=bn, activation=None, - upsample=None): - super(GBlock, self).__init__() - - self.in_channels, self.out_channels = in_channels, out_channels - self.which_conv, self.which_bn = which_conv, which_bn - self.activation = activation - self.upsample = upsample - # Conv layers - self.conv1 = self.which_conv(self.in_channels, self.out_channels) - self.conv2 = self.which_conv(self.out_channels, self.out_channels) - self.learnable_sc = in_channels != out_channels or upsample - if self.learnable_sc: - self.conv_sc = self.which_conv(in_channels, out_channels, - kernel_size=1, padding=0) - # Batchnorm layers - self.bn1 = self.which_bn(in_channels) - self.bn2 = self.which_bn(out_channels) - # upsample layers - self.upsample = upsample - - def forward(self, x, y): - h = self.activation(self.bn1(x, y)) - if self.upsample: - h = self.upsample(h) - x = self.upsample(x) - h = self.conv1(h) - h = self.activation(self.bn2(h, y)) - h = self.conv2(h) - if self.learnable_sc: - x = self.conv_sc(x) - return h + x - - -# Residual block for the discriminator -class DBlock(nn.Module): - def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, - preactivation=False, activation=None, downsample=None, ): - super(DBlock, self).__init__() - self.in_channels, self.out_channels = in_channels, out_channels - # If using wide D (as in SA-GAN and BigGAN), change the channel pattern - self.hidden_channels = self.out_channels if wide else self.in_channels - self.which_conv = which_conv - self.preactivation = preactivation - self.activation = activation - self.downsample = downsample - - # Conv layers - self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) - self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) - self.learnable_sc = True if (in_channels != out_channels) or downsample else False - if self.learnable_sc: - self.conv_sc = self.which_conv(in_channels, out_channels, - kernel_size=1, padding=0) - - def shortcut(self, x): - if self.preactivation: - if self.learnable_sc: - x = self.conv_sc(x) - if self.downsample: - x = self.downsample(x) - else: - if self.downsample: - x = self.downsample(x) - if self.learnable_sc: - x = self.conv_sc(x) - return x - - def forward(self, x): - if self.preactivation: - # h = self.activation(x) # NOT TODAY SATAN - # Andy's note: This line *must* be an out-of-place ReLU or it - # will negatively affect the shortcut connection. - h = F.relu(x) - else: - h = x - h = self.conv1(h) - h = self.conv2(self.activation(h)) - if self.downsample: - h = self.downsample(h) - - return h + self.shortcut(x) - -# dogball \ No newline at end of file diff --git a/codes/models/archs/biggan_sync_batchnorm.py b/codes/models/archs/biggan_sync_batchnorm.py deleted file mode 100644 index 42e9c852..00000000 --- a/codes/models/archs/biggan_sync_batchnorm.py +++ /dev/null @@ -1,489 +0,0 @@ -# -*- coding: utf-8 -*- -# File : batchnorm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import collections - -import torch -import torch.nn.functional as F - -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast - -__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] - - -def _sum_ft(tensor): - """sum over the first and last dimention""" - return tensor.sum(dim=0).sum(dim=-1) - - -def _unsqueeze_ft(tensor): - """add new dementions at the front and the tail""" - return tensor.unsqueeze(0).unsqueeze(-1) - - -_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) -_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) - - -# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) - -class _SynchronizedBatchNorm(_BatchNorm): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): - super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) - - self._sync_master = SyncMaster(self._data_parallel_master) - - self._is_parallel = False - self._parallel_id = None - self._slave_pipe = None - - def forward(self, input, gain=None, bias=None): - # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. - if not (self._is_parallel and self.training): - out = F.batch_norm( - input, self.running_mean, self.running_var, self.weight, self.bias, - self.training, self.momentum, self.eps) - if gain is not None: - out = out + gain - if bias is not None: - out = out + bias - return out - - # Resize the input to (B, C, -1). - input_shape = input.size() - # print(input_shape) - input = input.view(input.size(0), input.size(1), -1) - - # Compute the sum and square-sum. - sum_size = input.size(0) * input.size(2) - input_sum = _sum_ft(input) - input_ssum = _sum_ft(input ** 2) - # Reduce-and-broadcast the statistics. - # print('it begins') - if self._parallel_id == 0: - mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) - else: - mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) - # if self._parallel_id == 0: - # # print('here') - # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) - # else: - # # print('there') - # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) - - # print('how2') - # num = sum_size - # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) - # Fix the graph - # sum = (sum.detach() - input_sum.detach()) + input_sum - # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum - - # mean = sum / num - # var = ssum / num - mean ** 2 - # # var = (ssum - mean * sum) / num - # inv_std = torch.rsqrt(var + self.eps) - - # Compute the output. - if gain is not None: - # print('gaining') - # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) - # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) - # output = input * scale - shift - output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) - elif self.affine: - # MJY:: Fuse the multiplication for speed. - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) - else: - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) - - # Reshape it. 
- return output.view(input_shape) - - def __data_parallel_replicate__(self, ctx, copy_id): - self._is_parallel = True - self._parallel_id = copy_id - - # parallel_id == 0 means master device. - if self._parallel_id == 0: - ctx.sync_master = self._sync_master - else: - self._slave_pipe = ctx.sync_master.register_slave(copy_id) - - def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" - - # Always using same "device order" makes the ReduceAdd operation faster. - # Thanks to:: Tete Xiao (http://tetexiao.com/) - intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) - - to_reduce = [i[1][:2] for i in intermediates] - to_reduce = [j for i in to_reduce for j in i] # flatten - target_gpus = [i[1].sum.get_device() for i in intermediates] - - sum_size = sum([i[1].sum_size for i in intermediates]) - sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) - mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) - - broadcasted = Broadcast.apply(target_gpus, mean, inv_std) - # print('a') - # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) - # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) - # print('b') - outputs = [] - for i, rec in enumerate(intermediates): - outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) - # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) - - return outputs - - def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with sum and square-sum. This method - also maintains the moving average on the master device.""" - assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' - mean = sum_ / size - sumvar = ssum - sum_ * mean - unbias_var = sumvar / (size - 1) - bias_var = sumvar / size - - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - return mean, torch.rsqrt(bias_var + self.eps) - # return mean, bias_var.clamp(self.eps) ** -0.5 - - -class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): - r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a - mini-batch. - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm1d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. 
- - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm - - Args: - num_features: num_features from an expected input of size - `batch_size x num_features [x width]` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm1d, self)._check_input_dim(input) - - -class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch - of 3d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm2d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. 
Default: ``True`` - - Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm2d, self)._check_input_dim(input) - - -class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch - of 4d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm3d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm - or Spatio-temporal BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x depth x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape: - - Input: :math:`(N, C, D, H, W)` - - Output: :math:`(N, C, D, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm3d, self)._check_input_dim(input) - - -# From ccomm.py -# -*- coding: utf-8 -*- -# File : comm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import queue -import collections -import threading - -__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] - - -class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" - - def __init__(self): - self._result = None - self._lock = threading.Lock() - self._cond = threading.Condition(self._lock) - - def put(self, result): - with self._lock: - assert self._result is None, 'Previous result has\'t been fetched.' - self._result = result - self._cond.notify() - - def get(self): - with self._lock: - if self._result is None: - self._cond.wait() - - res = self._result - self._result = None - return res - - -_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) -_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) - - -class SlavePipe(_SlavePipeBase): - """Pipe for master-slave communication.""" - - def run_slave(self, msg): - self.queue.put((self.identifier, msg)) - ret = self.result.get() - self.queue.put(True) - return ret - - -class SyncMaster(object): - """An abstract `SyncMaster` object. - - - During the replication, as the data parallel will trigger an callback of each module, all slave devices should - call `register(id)` and obtain an `SlavePipe` to communicate with the master. - - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, - and passed to a registered callback. - - After receiving the messages, the master device should gather the information and determine to message passed - back to each slave devices. - """ - - def __init__(self, master_callback): - """ - - Args: - master_callback: a callback to be invoked after having collected messages from slave devices. - """ - self._master_callback = master_callback - self._queue = queue.Queue() - self._registry = collections.OrderedDict() - self._activated = False - - def __getstate__(self): - return {'master_callback': self._master_callback} - - def __setstate__(self, state): - self.__init__(state['master_callback']) - - def register_slave(self, identifier): - """ - Register an slave device. - - Args: - identifier: an identifier, usually is the device id. - - Returns: a `SlavePipe` object which can be used to communicate with the master device. - - """ - if self._activated: - assert self._queue.empty(), 'Queue is not clean before next initialization.' - self._activated = False - self._registry.clear() - future = FutureResult() - self._registry[identifier] = _MasterRegistry(future) - return SlavePipe(identifier, self._queue, future) - - def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. - The messages were first collected from each devices (including the master device), and then - an callback will be invoked to compute the message to be sent back to each devices - (including the master device). - - Args: - master_msg: the message that the master want to send to itself. This will be placed as the first - message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. - - Returns: the message to be sent back to the master device. - - """ - self._activated = True - - intermediates = [(0, master_msg)] - for i in range(self.nr_slaves): - intermediates.append(self._queue.get()) - - results = self._master_callback(intermediates) - assert results[0][0] == 0, 'The first result should belongs to the master.' 
-
-        for i, res in results:
-            if i == 0:
-                continue
-            self._registry[i].result.put(res)
-
-        for i in range(self.nr_slaves):
-            assert self._queue.get() is True
-
-        return results[0][1]
-
-    @property
-    def nr_slaves(self):
-        return len(self._registry)
\ No newline at end of file