209 lines
9.1 KiB
Python
209 lines
9.1 KiB
Python
# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
|
|
import functools
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import init
|
|
import torch.nn.functional as F
|
|
|
|
import models.archs.biggan_layers as layers
|
|
|
|
# BigGAN-deep: uses a different resblock and pattern
|
|
|
|
# Architectures for G
|
|
# Attention is passed in in the format '32_64' to mean applying an attention
|
|
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
|
|
|
|
# Channel ratio is the ratio of a block's input channels to its hidden (bottleneck) channels.
|
|
class GBlock(nn.Module):
    """Bottleneck residual block used by the generator.

    The residual branch is BN-act-conv1 (1x1, down to ``hidden_channels``),
    BN-act, optional upsample, conv2, BN-act-conv3, BN-act-conv4 (1x1, up to
    ``out_channels``). The skip branch is the input itself, truncated to its
    first ``out_channels`` channels when the widths differ and upsampled
    together with the residual branch.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        which_conv: Factory called as ``which_conv(in, out, ...)``. ``conv2``
            and ``conv3`` are built without a kernel size, so the factory is
            expected to bind its own kernel size and padding (``Generator``
            passes a ``functools.partial`` that does so).
        which_bn: Factory called as ``which_bn(num_channels)``.
        activation: Callable applied after every normalisation layer. It is
            called unconditionally in ``forward``, so the ``None`` default
            must be replaced by the caller.
        upsample: Optional callable applied to both branches; skipped when
            falsy.
        channel_ratio: ``hidden_channels`` is ``in_channels // channel_ratio``.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
                 upsample=None, channel_ratio=4):
        super(GBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        self.which_conv = which_conv
        self.which_bn = which_bn
        self.activation = activation

        narrow = self.hidden_channels

        # Convolutions: 1x1 squeeze, two convs at the factory's default
        # kernel size, then 1x1 expand.
        self.conv1 = self.which_conv(self.in_channels, narrow,
                                     kernel_size=1, padding=0)
        self.conv2 = self.which_conv(narrow, narrow)
        self.conv3 = self.which_conv(narrow, narrow)
        self.conv4 = self.which_conv(narrow, self.out_channels,
                                     kernel_size=1, padding=0)

        # One normalisation layer in front of each convolution.
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(narrow)
        self.bn3 = self.which_bn(narrow)
        self.bn4 = self.which_bn(narrow)

        # Optional upsampling callable shared by both branches.
        self.upsample = upsample

    def forward(self, x):
        """Return the residual branch output added to the skip branch."""
        act = self.activation

        # Squeeze to the hidden width, then the next BN-activation.
        residual = self.bn1(x)
        residual = act(residual)
        residual = self.conv1(residual)
        residual = act(self.bn2(residual))

        # The skip branch keeps only the leading channels when widths differ.
        skip = x
        if self.in_channels != self.out_channels:
            skip = skip[:, :self.out_channels]

        # Both branches are upsampled at the same point.
        if self.upsample:
            residual = self.upsample(residual)
            skip = self.upsample(skip)

        # Two convs at the hidden width.
        residual = self.conv2(residual)
        residual = act(self.bn3(residual))
        residual = self.conv3(residual)

        # Expand back out with the final 1x1 conv.
        residual = act(self.bn4(residual))
        residual = self.conv4(residual)

        return residual + skip
|
|
|
|
|
|
def G_arch(ch=64, attention='64', base_width=64):
    """Build the generator architecture table, keyed by output resolution.

    Only an entry for resolution 128 is defined. It describes four stages
    through parallel lists (``in_channels``, ``out_channels``, ``upsample``,
    ``resolution``) plus an ``attention`` dict mapping each power of two from
    8 to 128 to whether attention is requested at that resolution.

    Args:
        ch: Channel width multiplier applied to the per-stage ratios.
        attention: Resolutions at which to apply attention, joined by
            underscores (e.g. ``'32_64'``). Empty fragments are ignored, so
            ``''`` requests no attention.
        base_width: Resolution of the first two stages; the last two stages
            are listed at twice this value.

    Returns:
        dict mapping resolution (int) to its architecture description.

    Raises:
        ValueError: If a non-empty fragment of ``attention`` is not an integer.
    """
    # Parse the spec once rather than once per candidate resolution, skipping
    # empty fragments so '' or a trailing '_' does not reach int('').
    attention_resolutions = {int(item)
                             for item in attention.split('_')
                             if item.strip()}

    arch = {}
    arch[128] = {
        'in_channels': [ch * item for item in [2, 2, 1, 1]],
        'out_channels': [ch * item for item in [2, 1, 1, 1]],
        'upsample': [False, True, False, False],
        'resolution': [base_width, base_width, base_width * 2, base_width * 2],
        'attention': {2 ** i: (2 ** i in attention_resolutions)
                      for i in range(3, 8)},
    }

    return arch
|
|
|
|
|
|
class Generator(nn.Module):
    """Fully convolutional generator assembled from ``GBlock`` stages.

    Layout: ``initial_conv`` (conv from 3 channels, batchnorm, activation),
    then ``G_depth`` GBlocks for every stage of the table returned by
    ``G_arch`` (with an attention layer appended after a stage when the table
    requests it at that stage's resolution), then ``output_layer``
    (batchnorm, activation, conv to 3 channels) followed by ``tanh``.

    Args:
        G_ch: Channel width multiplier passed to ``G_arch``.
        G_depth: Number of GBlocks per stage. Only the last block of a stage
            changes the channel width and (if the stage upsamples) upsamples.
        bottom_width: Stored as ``self.bottom_width``; not used elsewhere in
            this class.
        resolution: Key into the table returned by ``G_arch``; a missing key
            raises ``KeyError``.
        G_kernel_size: Stored as ``self.kernel_size``. NOTE(review): the conv
            factories below hard-code ``kernel_size=3, padding=1``, so this
            value currently has no effect on the layers.
        G_attn: Attention spec string passed to ``G_arch``.
        num_G_SVs: Passed to ``layers.SNConv2d`` as ``num_svs``.
        num_G_SV_itrs: Passed to ``layers.SNConv2d`` as ``num_itrs``.
        hier: Stored as ``self.hier``; not used elsewhere in this class.
        cross_replica: Passed to every ``layers.bn``.
        mybn: Passed to every ``layers.bn``.
        G_activation: Activation callable shared by all layers. The default
            instance is created once, when the signature is evaluated.
        BN_eps: ``eps`` for the batchnorm layers inside the GBlocks.
            NOTE(review): ``initial_conv`` and ``output_layer`` build
            ``layers.bn`` directly and do not receive ``BN_eps`` or
            ``norm_style``.
        SN_eps: ``eps`` passed to ``layers.SNConv2d``.
        G_init: Weight init style: ``'ortho'``, ``'N02'``, ``'glorot'`` or
            ``'xavier'``.
        skip_init: When true, ``init_weights`` is not called.
        G_param: ``'SN'`` selects ``layers.SNConv2d``; any other value selects
            ``nn.Conv2d``.
        norm_style: Passed to the batchnorm layers inside the GBlocks.
    """

    def __init__(self, G_ch=64, G_depth=2, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64',
                 num_G_SVs=1, num_G_SV_itrs=1, hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 BN_eps=1e-5, SN_eps=1e-12,
                 G_init='ortho', skip_init=False,
                 G_param='SN', norm_style='bn'):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Number of resblocks per stage
        self.G_depth = G_depth
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size (stored only; see the class docstring)
        self.kernel_size = G_kernel_size
        # Attention spec string
        self.attention = G_attn
        # Hierarchical latent space flag (stored only)
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm
        self.SN_eps = SN_eps
        # Architecture dict for the requested resolution
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # Which convs and batchnorms to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs,
                                                num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d,
                                                kernel_size=3, padding=1)

        self.which_bn = functools.partial(layers.bn,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Prepare model
        # First conv layer to project the 3-channel input into feature-space
        first_width = self.arch['in_channels'][0]
        self.initial_conv = nn.Sequential(
            self.which_conv(3, first_width),
            layers.bn(first_width,
                      cross_replica=self.cross_replica,
                      mybn=self.mybn),
            self.activation,
        )

        # `blocks` is a list of lists. Every GBlock sits in its own inner
        # list; an attention layer, when requested, is appended to the inner
        # list of the last GBlock of its stage. This nesting fixes the
        # parameter names, so it must not be flattened or regrouped.
        blocks = []
        last_in_stage = self.G_depth - 1
        for index, stage_out in enumerate(self.arch['out_channels']):
            stage_in = self.arch['in_channels'][index]
            stage_res = self.arch['resolution'][index]
            for g_index in range(self.G_depth):
                # Only the final block of a stage changes the channel width
                # and upsamples; earlier blocks keep the stage's input width.
                # This keeps consecutive blocks shape-compatible for any
                # G_depth >= 1, and matches the previous behaviour exactly
                # for G_depth == 2.
                is_last = g_index == last_in_stage
                if self.arch['upsample'][index] and is_last:
                    upsample = functools.partial(F.interpolate, scale_factor=2)
                else:
                    upsample = None
                blocks.append([GBlock(in_channels=stage_in,
                                      out_channels=stage_out if is_last else stage_in,
                                      which_conv=self.which_conv,
                                      which_bn=self.which_bn,
                                      activation=self.activation,
                                      upsample=upsample)])

            # If attention is requested at this stage's resolution, attach it
            # after the stage's last block.
            if self.arch['attention'][stage_res]:
                print('Adding attention layer in G at resolution %d' % stage_res)
                blocks[-1] += [layers.Attention(stage_out, self.which_conv)]

        # Wrap in ModuleLists so that everything is properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in blocks])

        # Output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        final_width = self.arch['out_channels'][-1]
        self.output_layer = nn.Sequential(
            layers.bn(final_width,
                      cross_replica=self.cross_replica,
                      mybn=self.mybn),
            self.activation,
            self.which_conv(final_width, 3),
        )

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

    def init_weights(self):
        """Initialise conv/linear/embedding weights according to ``self.init``.

        Also sets ``self.param_count`` to the number of parameter elements in
        the visited modules and prints it. An unrecognised init style prints a
        warning for each visited module and leaves its weights untouched.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == 'ortho':
                init.orthogonal_(module.weight)
            elif self.init == 'N02':
                init.normal_(module.weight, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                init.xavier_uniform_(module.weight)
            else:
                print('Init style not recognized...')
            self.param_count += sum(p.data.nelement()
                                    for p in module.parameters())
        # The original literal 'G''s' concatenated to "Gs"; emit "G's".
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def forward(self, z):
        """Run the generator.

        Args:
            z: Input tensor fed to ``initial_conv``, whose first conv is built
                with 3 input channels (presumably an ``(N, 3, H, W)`` image
                batch; confirm against the caller).

        Returns:
            A one-element tuple holding ``tanh`` of the output layer's result.
        """
        # First conv layer to convert into correct filter-space.
        h = self.initial_conv(z)
        # Outer loop over block groups, inner loop over the layers of a group
        # (a GBlock, optionally followed by attention).
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)

        # Apply batchnorm-relu-conv-tanh at output
        return (torch.tanh(self.output_layer(h)), )
|
|
|
|
def biggan_medium(num_filters):
    """Return a ``Generator`` whose channel multiplier is ``num_filters``.

    All other constructor arguments keep their defaults.
    """
    generator = Generator(G_ch=num_filters)
    return generator