import functools import torch from torch.nn import init import models.archs.biggan.biggan_layers as layers import torch.nn as nn # Discriminator architecture, same paradigm as G's above def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): arch = {} arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]], 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], 'downsample' : [True] * 6 + [False], 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] for i in range(2,8)}} arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]], 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]], 'downsample' : [True] * 5 + [False], 'resolution' : [64, 32, 16, 8, 4, 4], 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] for i in range(2,8)}} arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]], 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]], 'downsample' : [True] * 4 + [False], 'resolution' : [32, 16, 8, 4, 4], 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] for i in range(2,7)}} arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]], 'out_channels' : [item * ch for item in [4, 4, 4, 4]], 'downsample' : [True, True, False, False], 'resolution' : [16, 16, 16, 16], 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] for i in range(2,6)}} return arch class BigGanDiscriminator(nn.Module): def __init__(self, D_ch=64, D_wide=True, resolution=128, D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), SN_eps=1e-12, output_dim=1, D_fp16=False, D_init='ortho', skip_init=False, D_param='SN'): super(BigGanDiscriminator, self).__init__() # Width multiplier self.ch = D_ch # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? self.D_wide = D_wide # Resolution self.resolution = resolution # Kernel size self.kernel_size = D_kernel_size # Attention? self.attention = D_attn # Activation self.activation = D_activation # Initialization style self.init = D_init # Parameterization style self.D_param = D_param # Epsilon for Spectral Norm? self.SN_eps = SN_eps # Fp16? self.fp16 = D_fp16 # Architecture self.arch = D_arch(self.ch, self.attention)[resolution] # Which convs, batchnorms, and linear layers to use # No option to turn off SN in D right now if self.D_param == 'SN': self.which_conv = functools.partial(layers.SNConv2d, kernel_size=3, padding=1, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) self.which_linear = functools.partial(layers.SNLinear, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) self.which_embedding = functools.partial(layers.SNEmbedding, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) # Prepare model # self.blocks is a doubly-nested list of modules, the outer loop intended # to be over blocks at a given resolution (resblocks and/or self-attention) self.blocks = [] for index in range(len(self.arch['out_channels'])): self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], out_channels=self.arch['out_channels'][index], which_conv=self.which_conv, wide=self.D_wide, activation=self.activation, preactivation=(index > 0), downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] # If attention on this block, attach it to the end if self.arch['attention'][self.arch['resolution'][index]]: print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] # Turn self.blocks into a ModuleList so that it's all properly registered. self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) # Linear output layer. The output dimension is typically 1, but may be # larger if we're e.g. turning this into a VAE with an inference output self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) # Initialize weights if not skip_init: self.init_weights() # Initialize def init_weights(self): self.param_count = 0 for module in self.modules(): if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)): if self.init == 'ortho': init.orthogonal_(module.weight) elif self.init == 'N02': init.normal_(module.weight, 0, 0.02) elif self.init in ['glorot', 'xavier']: init.xavier_uniform_(module.weight) else: print('Init style not recognized...') self.param_count += sum([p.data.nelement() for p in module.parameters()]) print('Param count for D''s initialized parameters: %d' % self.param_count) def forward(self, x, y=None): # Stick x into h for cleaner for loops without flow control h = x # Loop over blocks for index, blocklist in enumerate(self.blocks): for block in blocklist: h = block(h) # Apply global sum pooling as in SN-GAN h = torch.sum(self.activation(h), [2, 3]) # Get initial class-unconditional output out = self.linear(h) return out