From 79593803f228f56470ea42caff7f92e71d4cb394 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 15 May 2020 07:40:45 -0600 Subject: [PATCH] biggan arch, initial work (not implemented) --- codes/models/archs/biggan_gen_arch.py | 284 ++++++++++++ codes/models/archs/biggan_layers.py | 464 ++++++++++++++++++++ codes/models/archs/biggan_sync_batchnorm.py | 351 +++++++++++++++ 3 files changed, 1099 insertions(+) create mode 100644 codes/models/archs/biggan_gen_arch.py create mode 100644 codes/models/archs/biggan_layers.py create mode 100644 codes/models/archs/biggan_sync_batchnorm.py diff --git a/codes/models/archs/biggan_gen_arch.py b/codes/models/archs/biggan_gen_arch.py new file mode 100644 index 00000000..6264967a --- /dev/null +++ b/codes/models/archs/biggan_gen_arch.py @@ -0,0 +1,284 @@ +# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py +import numpy as np +import math +import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +import models.archs.biggan_layers as layers +from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d + +# BigGAN-deep: uses a different resblock and pattern + +# Architectures for G +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. + +# Channel ratio is the ratio of +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=layers.bn, activation=None, + upsample=None, channel_ratio=4): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.hidden_channels = self.in_channels // channel_ratio + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels, + kernel_size=1, padding=0) + self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv4 = self.which_conv(self.hidden_channels, self.out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(self.in_channels) + self.bn2 = self.which_bn(self.hidden_channels) + self.bn3 = self.which_bn(self.hidden_channels) + self.bn4 = self.which_bn(self.hidden_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + # Project down to channel ratio + h = self.conv1(self.activation(self.bn1(x, y))) + # Apply next BN-ReLU + h = self.activation(self.bn2(h, y)) + # Drop channels in x if necessary + if self.in_channels != self.out_channels: + x = x[:, :self.out_channels] + # Upsample both h and x at this point + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + # 3x3 convs + h = self.conv2(h) + h = self.conv3(self.activation(self.bn3(h, y))) + # Final 1x1 conv + h = self.conv4(self.activation(self.bn4(h, y))) + return h + x + + +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample': [True] * 6, + 'resolution': [8, 16, 32, 64, 128, 256], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 9)}} + arch[128] = {'in_channels': [ch * item for 
item in [16, 16, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample': [True] * 5, + 'resolution': [8, 16, 32, 64, 128], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 8)}} + arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]], + 'out_channels': [ch * item for item in [16, 8, 4, 2]], + 'upsample': [True] * 4, + 'resolution': [8, 16, 32, 64], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 7)}} + arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]], + 'out_channels': [ch * item for item in [4, 4, 4]], + 'upsample': [True] * 3, + 'resolution': [8, 16, 32], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}} + + return arch + + +class Generator(nn.Module): + def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, num_G_SV_itrs=1, + G_shared=True, shared_dim=0, hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, + BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, + G_init='ortho', skip_init=False, no_optim=False, + G_param='SN', norm_style='bn', + **kwargs): + super(Generator, self).__init__() + # Channel width mulitplier + self.ch = G_ch + # Number of resblocks per stage + self.G_depth = G_depth + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial spatial dimensions + self.bottom_width = bottom_width + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? 
+ self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=(self.shared_dim + self.dim_z if self.G_shared + else self.n_classes), + norm_style=self.norm_style, + eps=self.BN_eps) + + # Prepare model + # If not using shared embeddings, self.shared is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + self.linear = self.which_linear(self.dim_z + self.shared_dim, + self.arch['in_channels'][0] * (self.bottom_width ** 2)) + + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['in_channels'][index] if g_index == 0 else + self.arch['out_channels'][index], + which_conv=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=2) + if self.arch['upsample'][index] and g_index == ( + self.G_depth - 1) else None))] + for g_index in range(self.G_depth)] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. + # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], 3)) + + # Initialize weights. Optionally skip init for testing. 
+        if not skip_init:
+            self.init_weights()
+
+        # Set up optimizer
+        # If this is an EMA copy, no need for an optim, so just return now
+        if no_optim:
+            return
+        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
+        if G_mixed_precision:
+            print('Using fp16 adam in G...')
+            import utils
+            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
+                                      betas=(self.B1, self.B2), weight_decay=0,
+                                      eps=self.adam_eps)
+        else:
+            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
+                                    betas=(self.B1, self.B2), weight_decay=0,
+                                    eps=self.adam_eps)
+
+        # LR scheduling, left here for forward compatibility
+        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
+        # self.j = 0
+
+    # Initialize
+    def init_weights(self):
+        self.param_count = 0
+        for module in self.modules():
+            if (isinstance(module, nn.Conv2d)
+                    or isinstance(module, nn.Linear)
+                    or isinstance(module, nn.Embedding)):
+                if self.init == 'ortho':
+                    init.orthogonal_(module.weight)
+                elif self.init == 'N02':
+                    init.normal_(module.weight, 0, 0.02)
+                elif self.init in ['glorot', 'xavier']:
+                    init.xavier_uniform_(module.weight)
+                else:
+                    print('Init style not recognized...')
+                self.param_count += sum([p.data.nelement() for p in module.parameters()])
+        print('Param count for G\'s initialized parameters: %d' % self.param_count)
+
+    # Note on this forward function: we pass in a y vector which has
+    # already been passed through G.shared to enable easy class-wise
+    # interpolation later. If we passed in the one-hot and then ran it through
+    # G.shared in this forward function, it would be harder to handle.
+    # NOTE: The z vs y dichotomy here is for compatibility with not-y
+    def forward(self, z, y):
+        # If hierarchical, concatenate zs and ys
+        if self.hier:
+            z = torch.cat([y, z], 1)
+            y = z
+        # First linear layer
+        h = self.linear(z)
+        # Reshape
+        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
+        # Loop over blocks
+        for index, blocklist in enumerate(self.blocks):
+            # Second inner loop in case block has multiple layers
+            for block in blocklist:
+                h = block(h, y)
+
+        # Apply batchnorm-relu-conv-tanh at output
+        return torch.tanh(self.output_layer(h))
\ No newline at end of file
diff --git a/codes/models/archs/biggan_layers.py b/codes/models/archs/biggan_layers.py
new file mode 100644
index 00000000..8e1179e5
--- /dev/null
+++ b/codes/models/archs/biggan_layers.py
@@ -0,0 +1,464 @@
+''' Layers
+    This file contains various layers for the BigGAN models.
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import init
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.nn import Parameter as P
+
+from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
+
+
+# Projection of x onto y
+def proj(x, y):
+    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
+
+
+# Orthogonalize x wrt list of vectors ys
+def gram_schmidt(x, ys):
+    for y in ys:
+        x = x - proj(x, y)
+    return x
+
+
+# Apply num_itrs steps of the power method to estimate top N singular values.
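+# A rough sketch of one update, assuming a single singular vector (num_svs=1) and
+# a 2-D weight matrix W: v <- normalize(u @ W), then u <- normalize(v @ W.t()), and
+# the singular value is estimated as (v @ W.t()) @ u.t(). Repeating this drives
+# (u, v) toward the leading singular pair of W, which is the quantity that
+# spectral normalization divides the weight by.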
+def power_iteration(W, u_, update=True, eps=1e-12): + # Lists holding singular vectors and values + us, vs, svs = [], [], [] + for i, u in enumerate(u_): + # Run one step of the power iteration + with torch.no_grad(): + v = torch.matmul(u, W) + # Run Gram-Schmidt to subtract components of all other singular vectors + v = F.normalize(gram_schmidt(v, vs), eps=eps) + # Add to the list + vs += [v] + # Update the other singular vector + u = torch.matmul(v, W.t()) + # Run Gram-Schmidt to subtract components of all other singular vectors + u = F.normalize(gram_schmidt(u, us), eps=eps) + # Add to the list + us += [u] + if update: + u_[i][:] = u + # Compute this singular value and add it to the list + svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] + # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] + return svs, us, vs + + +# Convenience passthrough function +class identity(nn.Module): + def forward(self, input): + return input + + +# Spectral normalization base class +class SN(object): + def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): + # Number of power iterations per step + self.num_itrs = num_itrs + # Number of singular values + self.num_svs = num_svs + # Transposed? + self.transpose = transpose + # Epsilon value for avoiding divide-by-0 + self.eps = eps + # Register a singular vector for each sv + for i in range(self.num_svs): + self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) + self.register_buffer('sv%d' % i, torch.ones(1)) + + # Singular vectors (u side) + @property + def u(self): + return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] + + # Singular values; + # note that these buffers are just for logging and are not used in training. + @property + def sv(self): + return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] + + # Compute the spectrally-normalized weight + def W_(self): + W_mat = self.weight.view(self.weight.size(0), -1) + if self.transpose: + W_mat = W_mat.t() + # Apply num_itrs power iterations + for _ in range(self.num_itrs): + svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) + # Update the svs + if self.training: + with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! 
+ for i, sv in enumerate(svs): + self.sv[i][:] = sv + return self.weight / svs[0] + + +# 2D Conv layer with spectral norm +class SNConv2d(nn.Conv2d, SN): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) + + def forward(self, x): + return F.conv2d(x, self.W_(), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# Linear layer with spectral norm +class SNLinear(nn.Linear, SN): + def __init__(self, in_features, out_features, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Linear.__init__(self, in_features, out_features, bias) + SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) + + def forward(self, x): + return F.linear(x, self.W_(), self.bias) + + +# Embedding layer with spectral norm +# We use num_embeddings as the dim instead of embedding_dim here +# for convenience sake +class SNEmbedding(nn.Embedding, SN): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, + max_norm=None, norm_type=2, scale_grad_by_freq=False, + sparse=False, _weight=None, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, + max_norm, norm_type, scale_grad_by_freq, + sparse, _weight) + SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) + + def forward(self, x): + return F.embedding(x, self.W_()) + + +# A non-local block as used in SA-GAN +# Note that the implementation as described in the paper is largely incorrect; +# refer to the released code for the actual implementation. +class Attention(nn.Module): + def __init__(self, ch, which_conv=SNConv2d, name='attention'): + super(Attention, self).__init__() + # Channel multiplier + self.ch = ch + self.which_conv = which_conv + self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) + self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + + def forward(self, x, y=None): + # Apply convs + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2, 2]) + g = F.max_pool2d(self.g(x), [2, 2]) + # Perform reshapes + theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3]) + phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4) + g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) + return self.gamma * o + x + + +# Fused batchnorm op +def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): + # Apply scale and shift--if gain and bias are provided, fuse them here + # Prepare scale + scale = torch.rsqrt(var + eps) + # If a gain is provided, use it + if gain is not None: + scale = scale * gain + # Prepare shift + shift = mean * scale + # If bias is provided, use it + if bias is not None: + shift = shift - bias + return x * scale - shift + # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused 
way. + + +# Manual BN +# Calculate means and variances using mean-of-squares minus mean-squared +def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): + # Cast x to float32 if necessary + float_x = x.float() + # Calculate expected value of x (m) and expected value of x**2 (m2) + # Mean of x + m = torch.mean(float_x, [0, 2, 3], keepdim=True) + # Mean of x squared + m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) + # Calculate variance as mean of squared minus mean squared. + var = (m2 - m ** 2) + # Cast back to float 16 if necessary + var = var.type(x.type()) + m = m.type(x.type()) + # Return mean and variance for updating stored mean/var if requested + if return_mean_var: + return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() + else: + return fused_bn(x, m, var, gain, bias, eps) + + +# My batchnorm, supports standing stats +class myBN(nn.Module): + def __init__(self, num_channels, eps=1e-5, momentum=0.1): + super(myBN, self).__init__() + # momentum for updating running stats + self.momentum = momentum + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Register buffers + self.register_buffer('stored_mean', torch.zeros(num_channels)) + self.register_buffer('stored_var', torch.ones(num_channels)) + self.register_buffer('accumulation_counter', torch.zeros(1)) + # Accumulate running means and vars + self.accumulate_standing = False + + # reset standing stats + def reset_stats(self): + self.stored_mean[:] = 0 + self.stored_var[:] = 0 + self.accumulation_counter[:] = 0 + + def forward(self, x, gain, bias): + if self.training: + out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) + # If accumulating standing stats, increment them + if self.accumulate_standing: + self.stored_mean[:] = self.stored_mean + mean.data + self.stored_var[:] = self.stored_var + var.data + self.accumulation_counter += 1.0 + # If not accumulating standing stats, take running averages + else: + self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum + self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum + return out + # If not in training mode, use the stored statistics + else: + mean = self.stored_mean.view(1, -1, 1, 1) + var = self.stored_var.view(1, -1, 1, 1) + # If using standing stats, divide them by the accumulation counter + if self.accumulate_standing: + mean = mean / self.accumulation_counter + var = var / self.accumulation_counter + return fused_bn(x, mean, var, gain, bias, self.eps) + + +# Simple function to handle groupnorm norm stylization +def groupnorm(x, norm_style): + # If number of channels specified in norm_style: + if 'ch' in norm_style: + ch = int(norm_style.split('_')[-1]) + groups = max(int(x.shape[1]) // ch, 1) + # If number of groups specified in norm style + elif 'grp' in norm_style: + groups = int(norm_style.split('_')[-1]) + # If neither, default to groups = 16 + else: + groups = 16 + return F.group_norm(x, groups) + + +# Class-conditional bn +# output size is the number of channels, input size is for the linear layers +# Andy's Note: this class feels messy but I'm not really sure how to clean it up +# Suggestions welcome! (By which I mean, refactor this and make a pull request +# if you want to make this more readable/usable). 
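+# Illustrative usage sketch (hypothetical shapes and arguments, not part of the
+# original code):
+#   cc = ccbn(output_size=256, input_size=128, which_linear=nn.Linear)
+#   out = cc(x, y)   # x: [N, 256, H, W] activations, y: [N, 128] conditioning vector
+# Internally the block predicts gain = 1 + Linear(y) and bias = Linear(y), reshapes
+# both to [N, C, 1, 1], and applies them after the chosen normalization.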
+class ccbn(nn.Module):
+    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
+                 cross_replica=False, mybn=False, norm_style='bn'):
+        super(ccbn, self).__init__()
+        self.output_size, self.input_size = output_size, input_size
+        # Prepare gain and bias layers
+        self.gain = which_linear(input_size, output_size)
+        self.bias = which_linear(input_size, output_size)
+        # epsilon to avoid dividing by 0
+        self.eps = eps
+        # Momentum
+        self.momentum = momentum
+        # Use cross-replica batchnorm?
+        self.cross_replica = cross_replica
+        # Use my batchnorm?
+        self.mybn = mybn
+        # Norm style?
+        self.norm_style = norm_style
+
+        if self.cross_replica:
+            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
+        elif self.mybn:
+            self.bn = myBN(output_size, self.eps, self.momentum)
+        elif self.norm_style in ['bn', 'in']:
+            self.register_buffer('stored_mean', torch.zeros(output_size))
+            self.register_buffer('stored_var', torch.ones(output_size))
+
+    def forward(self, x, y):
+        # Calculate class-conditional gains and biases
+        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
+        bias = self.bias(y).view(y.size(0), -1, 1, 1)
+        # If using my batchnorm
+        if self.mybn or self.cross_replica:
+            return self.bn(x, gain=gain, bias=bias)
+        else:
+            if self.norm_style == 'bn':
+                out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
+                                   self.training, 0.1, self.eps)
+            elif self.norm_style == 'in':
+                out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
+                                      self.training, 0.1, self.eps)
+            elif self.norm_style == 'gn':
+                out = groupnorm(x, self.norm_style)
+            elif self.norm_style == 'nonorm':
+                out = x
+            return out * gain + bias
+
+    def extra_repr(self):
+        s = 'out: {output_size}, in: {input_size},'
+        s += ' cross_replica={cross_replica}'
+        return s.format(**self.__dict__)
+
+
+# Normal, non-class-conditional BN
+class bn(nn.Module):
+    def __init__(self, output_size, eps=1e-5, momentum=0.1,
+                 cross_replica=False, mybn=False):
+        super(bn, self).__init__()
+        self.output_size = output_size
+        # Prepare gain and bias layers
+        self.gain = P(torch.ones(output_size), requires_grad=True)
+        self.bias = P(torch.zeros(output_size), requires_grad=True)
+        # epsilon to avoid dividing by 0
+        self.eps = eps
+        # Momentum
+        self.momentum = momentum
+        # Use cross-replica batchnorm?
+        self.cross_replica = cross_replica
+        # Use my batchnorm?
+        self.mybn = mybn
+
+        if self.cross_replica:
+            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
+        elif mybn:
+            self.bn = myBN(output_size, self.eps, self.momentum)
+        # Register buffers if neither of the above
+        else:
+            self.register_buffer('stored_mean', torch.zeros(output_size))
+            self.register_buffer('stored_var', torch.ones(output_size))
+
+    def forward(self, x, y=None):
+        if self.cross_replica or self.mybn:
+            gain = self.gain.view(1, -1, 1, 1)
+            bias = self.bias.view(1, -1, 1, 1)
+            return self.bn(x, gain=gain, bias=bias)
+        else:
+            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
+                                self.bias, self.training, self.momentum, self.eps)
+
+
+# Generator blocks
+# Note that this class assumes the kernel size and padding (and any other
+# settings) have been selected in the main generator module and passed in
+# through the which_conv arg.
Similar rules apply with which_bn (the input +# size [which is actually the number of channels of the conditional info] must +# be preselected) +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=bn, activation=None, + upsample=None): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + self.upsample = upsample + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.out_channels) + self.conv2 = self.which_conv(self.out_channels, self.out_channels) + self.learnable_sc = in_channels != out_channels or upsample + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(in_channels) + self.bn2 = self.which_bn(out_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + h = self.activation(self.bn1(x, y)) + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + h = self.conv1(h) + h = self.activation(self.bn2(h, y)) + h = self.conv2(h) + if self.learnable_sc: + x = self.conv_sc(x) + return h + x + + +# Residual block for the discriminator +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, + preactivation=False, activation=None, downsample=None, ): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels if wide else self.in_channels + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) + self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) + self.learnable_sc = True if (in_channels != out_channels) or downsample else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + + def shortcut(self, x): + if self.preactivation: + if self.learnable_sc: + x = self.conv_sc(x) + if self.downsample: + x = self.downsample(x) + else: + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = self.conv_sc(x) + return x + + def forward(self, x): + if self.preactivation: + # h = self.activation(x) # NOT TODAY SATAN + # Andy's note: This line *must* be an out-of-place ReLU or it + # will negatively affect the shortcut connection. + h = F.relu(x) + else: + h = x + h = self.conv1(h) + h = self.conv2(self.activation(h)) + if self.downsample: + h = self.downsample(h) + + return h + self.shortcut(x) + +# dogball \ No newline at end of file diff --git a/codes/models/archs/biggan_sync_batchnorm.py b/codes/models/archs/biggan_sync_batchnorm.py new file mode 100644 index 00000000..a55a75a9 --- /dev/null +++ b/codes/models/archs/biggan_sync_batchnorm.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
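+# NOTE: this module expects the SyncMaster helper from the upstream repo's
+# sync_batchnorm/comm.py (imported below via `from .comm import SyncMaster`).
+# That helper is not included in this patch, so the import will fail until a
+# matching comm module is added alongside this file.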
+ +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input, gain=None, bias=None): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + out = F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + if gain is not None: + out = out + gain + if bias is not None: + out = out + bias + return out + + # Resize the input to (B, C, -1). + input_shape = input.size() + # print(input_shape) + input = input.view(input.size(0), input.size(1), -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + # Reduce-and-broadcast the statistics. + # print('it begins') + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + # if self._parallel_id == 0: + # # print('here') + # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + # else: + # # print('there') + # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # print('how2') + # num = sum_size + # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) + # Fix the graph + # sum = (sum.detach() - input_sum.detach()) + input_sum + # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum + + # mean = sum / num + # var = ssum / num - mean ** 2 + # # var = (ssum - mean * sum) / num + # inv_std = torch.rsqrt(var + self.eps) + + # Compute the output. + if gain is not None: + # print('gaining') + # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) + # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) + # output = input * scale - shift + output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) + elif self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. 
+ return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + # print('a') + # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) + # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) + # print('b') + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) + # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + return mean, torch.rsqrt(bias_var + self.eps) + # return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. 
+ + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file