Lots of new discriminator nets

This commit is contained in:
James Betker 2020-11-10 16:06:54 -07:00
parent 4e5ba61ae7
commit 6a2fd5f7d0
12 changed files with 990 additions and 442 deletions

codes/models/__init__.py Normal file
View File

View File

@@ -1,197 +0,0 @@
import os
import torch
import torchvision
from torch import nn
from models.archs.SPSR_arch import ImageGradientNoPadding
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
FinalUpsampleBlock2x, ReferenceJoinBlock
from models.archs.spinenet_arch import SpineNet
from utils.util import checkpoint
class BasicEmbeddingPyramid(nn.Module):
def __init__(self, use_norms=True):
super(BasicEmbeddingPyramid, self).__init__()
self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False)
self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms),
ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)])
self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu),
ExpansionBlock2(128, 64, block=ConvGnLelu)])
self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False)
self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms)
self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False)
self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms)
self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1),
ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1),
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1),
ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
weight_init_factor=.1))
def forward(self, x, *embeddings):
p = self.initial_process(x)
identities = []
for i in range(2):
identities.append(p)
p = self.reducers[i*2](p)
p = self.reducers[i*2+1](p)
if i == 0:
p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0]))
elif i == 1:
p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1]))
for i in range(2):
p = self.expanders[i](p, identities[-(i+1)])
x = self.final_process(torch.cat([x, p], dim=1))
return x, p
class ChainedEmbeddingGenWithStructure(nn.Module):
def __init__(self, in_nc=3, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
super(ChainedEmbeddingGenWithStructure, self).__init__()
self.recurrent = recurrent
self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False)
if recurrent:
self.recurrent_nf = recurrent_nf
self.recurrent_stride = recurrent_stride
self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
self.structure_upsample = FinalUpsampleBlock2x(64)
self.grad_extract = ImageGradientNoPadding()
self.upsample = FinalUpsampleBlock2x(64)
self.ref_join_std = 0
def forward(self, x, recurrent=None):
fea = self.initial_conv(x)
if self.recurrent:
if recurrent is None:
if self.recurrent_nf == 3:
recurrent = torch.zeros_like(x)
if self.recurrent_stride != 1:
recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest')
else:
recurrent = torch.zeros_like(fea)
rec = self.recurrent_process(recurrent)
fea, recstd = self.recurrent_join(fea, rec)
self.ref_join_std = recstd.item()
if self.spine is not None:
emb = checkpoint(self.spine, fea)
else:
b,f,h,w = fea.shape
emb = (torch.zeros((b,f,h//2,w//2), device=fea.device),
torch.zeros((b,f,h//4,w//4), device=fea.device))
grad = fea
for i, block in enumerate(self.blocks):
fea = fea + checkpoint(block, fea, *emb)[0]
if i < 3:
structure_br = checkpoint(self.structure_joins[i], grad, fea)
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
out = checkpoint(self.upsample, fea)
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
def get_debug_values(self, step, net_name):
return { 'ref_join_std': self.ref_join_std }
# This is a structural block that learns to mute regions of a residual transformation given a signal.
class OptionalPassthroughBlock(nn.Module):
def __init__(self, nf, initial_bias=10):
super(OptionalPassthroughBlock, self).__init__()
self.switch_process = nn.Sequential(ConvGnLelu(nf, nf // 2, 1, activation=False, norm=False, bias=False),
ConvGnLelu(nf // 2, nf // 4, 1, activation=False, norm=False, bias=False),
ConvGnLelu(nf // 4, 1, 1, activation=False, norm=False, bias=False))
self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float), requires_grad=True)
self.activation = nn.Sigmoid()
def forward(self, x, switch_signal):
switch = self.switch_process(switch_signal)
bypass_map = self.activation(self.bias + switch)
return x * bypass_map, bypass_map
class MultifacetedChainedEmbeddingGen(nn.Module):
def __init__(self, depth=10, scale=2):
super(MultifacetedChainedEmbeddingGen, self).__init__()
assert scale == 2 or scale == 4
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
if scale == 2:
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
else:
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=7, stride=4, norm=False, bias=True, activation=False)
self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)])
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
self.structure_upsample = FinalUpsampleBlock2x(64, scale=scale)
self.grad_extract = ImageGradientNoPadding()
self.upsample = FinalUpsampleBlock2x(64, scale=scale)
self.teco_ref_std = 0
self.prog_ref_std = 0
self.block_residual_means = [0 for _ in range(depth)]
self.block_residual_stds = [0 for _ in range(depth)]
self.bypass_maps = []
def forward(self, x, teco_recurrent=None, prog_recurrent=None):
fea = self.initial_conv(x)
# Integrate recurrence inputs.
if teco_recurrent is not None:
teco_rec = self.teco_recurrent_process(teco_recurrent)
fea, std = self.teco_recurrent_join(fea, teco_rec)
self.teco_ref_std = std.item()
elif prog_recurrent is not None:
prog_rec = self.prog_recurrent_process(prog_recurrent)
fea, std = self.prog_recurrent_join(fea, prog_rec)
self.prog_ref_std = std.item()
emb = checkpoint(self.spine, fea)
grad = fea
self.bypass_maps = []
for i, block in enumerate(self.blocks):
residual, context = checkpoint(block, fea, *emb)
residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
fea = fea + residual
self.bypass_maps.append(bypass_map.detach())
self.block_residual_means[i] = residual.mean().item()
self.block_residual_stds[i] = residual.std().item()
if i < 3:
structure_br = checkpoint(self.structure_joins[i], grad, fea)
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
out = checkpoint(self.upsample, fea)
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
def visual_dbg(self, step, path):
for i, bm in enumerate(self.bypass_maps):
torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, step, net_name):
biases = [b.bias.item() for b in self.bypasses]
blk_stds, blk_means = {}, {}
for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
blk_stds['block_%i' % (i+1,)] = s
blk_means['block_%i' % (i+1,)] = m
return {'teco_std': self.teco_ref_std,
'prog_std': self.prog_ref_std,
'bypass_biases': sum(biases) / len(biases),
'blocks_std': blk_stds, 'blocks_mean': blk_means}
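# --- Hedged usage sketch (editor's addition, not part of the commit) ---
# Shapes are assumptions inferred from the 64-channel trunk and the 2x
# upsampling head; utils.util.checkpoint must be callable outside a training
# loop for this to run as-is.
def _demo_multifaceted_gen():
    gen = MultifacetedChainedEmbeddingGen(depth=4, scale=2)
    lr = torch.randn(1, 3, 64, 64)
    out, struct_grad, out_grad, fea = gen(lr)
    # expected: out is (1, 3, 128, 128); fea keeps the 64-channel trunk resolution.
    print(out.shape, struct_grad.shape, out_grad.shape, fea.shape)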

View File

@@ -1,225 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv5x5(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
padding=2, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class FixupBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, use_bn=False, conv_create=conv3x3):
super(FixupBasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.bias1a = nn.Parameter(torch.zeros(1))
self.conv1 = conv_create(inplanes, planes, stride)
self.bias1b = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bias2a = nn.Parameter(torch.zeros(1))
self.conv2 = conv_create(planes, planes)
self.scale = nn.Parameter(torch.ones(1))
self.bias2b = nn.Parameter(torch.zeros(1))
self.downsample = downsample
self.stride = stride
self.use_bn = use_bn
if use_bn:
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
if self.use_bn:
out = self.bn1(out)
out = self.lrelu(out + self.bias1b)
out = self.conv2(out + self.bias2a)
if self.use_bn:
out = self.bn2(out)
out = out * self.scale + self.bias2b
if self.downsample is not None:
identity = self.downsample(x + self.bias1a)
out += identity
out = self.lrelu(out)
return out
class FixupBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(FixupBottleneck, self).__init__()
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.bias1a = nn.Parameter(torch.zeros(1))
self.conv1 = conv1x1(inplanes, planes)
self.bias1b = nn.Parameter(torch.zeros(1))
self.bias2a = nn.Parameter(torch.zeros(1))
self.conv2 = conv3x3(planes, planes, stride)
self.bias2b = nn.Parameter(torch.zeros(1))
self.bias3a = nn.Parameter(torch.zeros(1))
self.conv3 = conv1x1(planes, planes * self.expansion)
self.scale = nn.Parameter(torch.ones(1))
self.bias3b = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
out = self.lrelu(out + self.bias1b)
out = self.conv2(out + self.bias2a)
out = self.lrelu(out + self.bias2b)
out = self.conv3(out + self.bias3a)
out = out * self.scale + self.bias3b
if self.downsample is not None:
identity = self.downsample(x + self.bias1a)
out += identity
out = self.lrelu(out)
return out
class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False,
disable_passthrough=False):
super(FixupResNet, self).__init__()
self.num_layers = sum(layers)
self.inplanes = 3
self.number_skips = number_skips
self.disable_passthrough = disable_passthrough
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
if number_skips > 0:
self.inplanes = self.inplanes + 3 # Accommodate a skip connection from the generator.
self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
if number_skips > 1:
self.inplanes = self.inplanes + 3 # Accommodate a second skip connection from the generator.
self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
# SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
# Therefore, level off the filter count from this block forwards.
self.layer3 = self._make_layer(block, num_filters*8, layers[3], stride=2, use_bn=use_bn)
self.layer4 = self._make_layer(block, num_filters*8, layers[4], stride=2, use_bn=use_bn)
self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32)
self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = nn.Linear(100, num_classes)
for m in self.modules():
if isinstance(m, FixupBasicBlock):
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
nn.init.constant_(m.conv2.weight, 0)
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
elif isinstance(m, FixupBottleneck):
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
nn.init.constant_(m.conv3.weight, 0)
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
def _make_layer(self, block, outplanes, blocks, stride=1, use_bn=False, conv_type=conv3x3):
layers = []
for _ in range(1, blocks):
layers.append(block(self.inplanes, self.inplanes))
downsample = None
if stride != 1 or self.inplanes != outplanes * block.expansion:
downsample = conv1x1(self.inplanes, outplanes * block.expansion, stride)
layers.append(block(self.inplanes, outplanes, stride, downsample, use_bn=use_bn, conv_create=conv_type))
self.inplanes = outplanes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
if len(x) == 3:
# This class can take a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
x, med_skip, lo_skip = x
else:
# Or just a tuple with only the high res input (this assumes number_skips was set right).
x = x[0]
if self.disable_passthrough:
if self.number_skips > 0:
med_skip = torch.zeros_like(med_skip)
if self.number_skips > 1:
lo_skip = torch.zeros_like(lo_skip)
x = self.layer0(x)
if self.number_skips > 0:
x = torch.cat([x, med_skip], dim=1)
x = self.layer1(x)
if self.number_skips > 1:
x = torch.cat([x, lo_skip], dim=1)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.view(x.size(0), -1)
x = self.lrelu(self.fc1(x))
x = self.fc2(x + self.bias2)
return x
def fixup_resnet18(**kwargs):
"""Constructs a Fixup-ResNet-18 model.2
"""
model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2, 2], **kwargs)
return model
def fixup_resnet34(**kwargs):
"""Constructs a Fixup-ResNet-34 model.
"""
model = FixupResNet(FixupBasicBlock, [5, 5, 3, 3, 3], **kwargs)
return model
def fixup_resnet50(**kwargs):
"""Constructs a Fixup-ResNet-50 model.
"""
model = FixupResNet(FixupBottleneck, [3, 4, 6, 3, 2], **kwargs)
return model
def fixup_resnet101(**kwargs):
"""Constructs a Fixup-ResNet-101 model.
"""
model = FixupResNet(FixupBottleneck, [3, 4, 23, 3, 2], **kwargs)
return model
def fixup_resnet152(**kwargs):
"""Constructs a Fixup-ResNet-152 model.
"""
model = FixupResNet(FixupBottleneck, [3, 8, 36, 3, 2], **kwargs)
return model
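# --- Hedged usage sketch (editor's addition, not part of the commit) ---
# With number_skips=2 the forward pass expects a (full, half, quarter)
# resolution tuple; input_img_size must match the full-res input, since the
# fc1 width is derived from it (five stride-2 stages => /32).
def _demo_fixup_passthrough():
    d = fixup_resnet34(num_filters=64, num_classes=1, input_img_size=128, number_skips=2)
    hi = torch.randn(4, 3, 128, 128)
    med = torch.randn(4, 3, 64, 64)
    lo = torch.randn(4, 3, 32, 32)
    print(d((hi, med, lo)).shape)  # expected: (4, 1)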

View File

@@ -0,0 +1,139 @@
import functools
import torch
from torch.nn import init
import models.archs.biggan.biggan_layers as layers
import torch.nn as nn
# Discriminator architecture dictionary, same paradigm as BigGAN's G_arch (the generator is not included here)
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
arch = {}
arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
'downsample' : [True] * 6 + [False],
'resolution' : [128, 64, 32, 16, 8, 4, 4 ],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,8)}}
arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]],
'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
'downsample' : [True] * 5 + [False],
'resolution' : [64, 32, 16, 8, 4, 4],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,8)}}
arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]],
'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
'downsample' : [True] * 4 + [False],
'resolution' : [32, 16, 8, 4, 4],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,7)}}
arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]],
'out_channels' : [item * ch for item in [4, 4, 4, 4]],
'downsample' : [True, True, False, False],
'resolution' : [16, 16, 16, 16],
'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
for i in range(2,6)}}
return arch
class BigGanDiscriminator(nn.Module):
def __init__(self, D_ch=64, D_wide=True, resolution=128,
D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
SN_eps=1e-12, output_dim=1, D_fp16=False,
D_init='ortho', skip_init=False, D_param='SN'):
super(BigGanDiscriminator, self).__init__()
# Width multiplier
self.ch = D_ch
# Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
self.D_wide = D_wide
# Resolution
self.resolution = resolution
# Kernel size
self.kernel_size = D_kernel_size
# Attention?
self.attention = D_attn
# Activation
self.activation = D_activation
# Initialization style
self.init = D_init
# Parameterization style
self.D_param = D_param
# Epsilon for Spectral Norm?
self.SN_eps = SN_eps
# Fp16?
self.fp16 = D_fp16
# Architecture
self.arch = D_arch(self.ch, self.attention)[resolution]
# Which convs, batchnorms, and linear layers to use
# No option to turn off SN in D right now
if self.D_param == 'SN':
self.which_conv = functools.partial(layers.SNConv2d,
kernel_size=3, padding=1,
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
eps=self.SN_eps)
self.which_linear = functools.partial(layers.SNLinear,
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
eps=self.SN_eps)
self.which_embedding = functools.partial(layers.SNEmbedding,
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
eps=self.SN_eps)
# Prepare model
# self.blocks is a doubly-nested list of modules, the outer loop intended
# to be over blocks at a given resolution (resblocks and/or self-attention)
self.blocks = []
for index in range(len(self.arch['out_channels'])):
self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
out_channels=self.arch['out_channels'][index],
which_conv=self.which_conv,
wide=self.D_wide,
activation=self.activation,
preactivation=(index > 0),
downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
# If attention on this block, attach it to the end
if self.arch['attention'][self.arch['resolution'][index]]:
print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
self.which_conv)]
# Turn self.blocks into a ModuleList so that it's all properly registered.
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
# Linear output layer. The output dimension is typically 1, but may be
# larger if we're e.g. turning this into a VAE with an inference output
self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
# Initialize weights
if not skip_init:
self.init_weights()
# Initialize
def init_weights(self):
self.param_count = 0
for module in self.modules():
if (isinstance(module, nn.Conv2d)
or isinstance(module, nn.Linear)
or isinstance(module, nn.Embedding)):
if self.init == 'ortho':
init.orthogonal_(module.weight)
elif self.init == 'N02':
init.normal_(module.weight, 0, 0.02)
elif self.init in ['glorot', 'xavier']:
init.xavier_uniform_(module.weight)
else:
print('Init style not recognized...')
self.param_count += sum([p.data.nelement() for p in module.parameters()])
print("Param count for D's initialized parameters: %d" % self.param_count)
def forward(self, x, y=None):
# Stick x into h for cleaner for loops without flow control
h = x
# Loop over blocks
for index, blocklist in enumerate(self.blocks):
for block in blocklist:
h = block(h)
# Apply global sum pooling as in SN-GAN
h = torch.sum(self.activation(h), [2, 3])
# Get initial class-unconditional output
out = self.linear(h)
return out
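# --- Hedged usage sketch (editor's addition, not part of the commit) ---
# The discriminator maps an image batch to one logit per sample; resolution
# must be one of the keys built by D_arch (32, 64, 128, 256). 'N02' init is
# used here to keep the demo cheap; the default is orthogonal init.
def _demo_biggan_d():
    D = BigGanDiscriminator(D_ch=64, resolution=128, D_init='N02')
    print(D(torch.randn(2, 3, 128, 128)).shape)  # expected: (2, 1)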

View File

@@ -0,0 +1,457 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
# Projection of x onto y
def proj(x, y):
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
for y in ys:
x = x - proj(x, y)
return x
# Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
# Lists holding singular vectors and values
us, vs, svs = [], [], []
for i, u in enumerate(u_):
# Run one step of the power iteration
with torch.no_grad():
v = torch.matmul(u, W)
# Run Gram-Schmidt to subtract components of all other singular vectors
v = F.normalize(gram_schmidt(v, vs), eps=eps)
# Add to the list
vs += [v]
# Update the other singular vector
u = torch.matmul(v, W.t())
# Run Gram-Schmidt to subtract components of all other singular vectors
u = F.normalize(gram_schmidt(u, us), eps=eps)
# Add to the list
us += [u]
if update:
u_[i][:] = u
# Compute this singular value and add it to the list
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
# svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
return svs, us, vs
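# --- Hedged sanity check (editor's addition, not part of the commit) ---
# The leading singular value estimated by the power method should approach
# the true top singular value of W after enough iterations; torch.svd is used
# for comparison since it was the stable API in PyTorch versions of this era.
def _demo_power_iteration():
    W = torch.randn(8, 16)
    u = [torch.randn(1, 8)]
    for _ in range(50):
        svs, _, _ = power_iteration(W, u, update=True)
    print('estimated %.4f vs true %.4f' % (svs[0].item(), torch.svd(W).S[0].item()))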
# Convenience passthrough function
class identity(nn.Module):
def forward(self, input):
return input
# Spectral normalization base class
class SN(object):
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
# Number of power iterations per step
self.num_itrs = num_itrs
# Number of singular values
self.num_svs = num_svs
# Transposed?
self.transpose = transpose
# Epsilon value for avoiding divide-by-0
self.eps = eps
# Register a singular vector for each sv
for i in range(self.num_svs):
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
self.register_buffer('sv%d' % i, torch.ones(1))
# Singular vectors (u side)
@property
def u(self):
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
# Singular values;
# note that these buffers are just for logging and are not used in training.
@property
def sv(self):
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
# Compute the spectrally-normalized weight
def W_(self):
W_mat = self.weight.view(self.weight.size(0), -1)
if self.transpose:
W_mat = W_mat.t()
# Apply num_itrs power iterations
for _ in range(self.num_itrs):
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
# Update the svs
if self.training:
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
for i, sv in enumerate(svs):
self.sv[i][:] = sv
return self.weight / svs[0]
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
def forward(self, x):
return F.conv2d(x, self.W_(), self.bias, self.stride,
self.padding, self.dilation, self.groups)
# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
def __init__(self, in_features, out_features, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Linear.__init__(self, in_features, out_features, bias)
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
def forward(self, x):
return F.linear(x, self.W_(), self.bias)
# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=None,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
max_norm, norm_type, scale_grad_by_freq,
sparse, _weight)
SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
def forward(self, x):
return F.embedding(x, self.W_())
# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
def __init__(self, ch, which_conv=SNConv2d, name='attention'):
super(Attention, self).__init__()
# Channel multiplier
self.ch = ch
self.which_conv = which_conv
self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
# Learnable gain parameter
self.gamma = P(torch.tensor(0.), requires_grad=True)
def forward(self, x, y=None):
# Apply convs
theta = self.theta(x)
phi = F.max_pool2d(self.phi(x), [2, 2])
g = F.max_pool2d(self.g(x), [2, 2])
# Perform reshapes
theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
# Matmul and softmax to get attention maps
beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
# Attention map times g path
o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
return self.gamma * o + x
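# --- Hedged shape check (editor's addition, not part of the commit) ---
# The self-attention block is shape-preserving, and since gamma starts at 0
# it acts as an identity mapping at initialization.
def _demo_attention():
    attn = Attention(64)
    x = torch.randn(2, 64, 16, 16)
    print(attn(x).shape, torch.allclose(attn(x), x))  # expected: (2, 64, 16, 16) True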
# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
# Apply scale and shift--if gain and bias are provided, fuse them here
# Prepare scale
scale = torch.rsqrt(var + eps)
# If a gain is provided, use it
if gain is not None:
scale = scale * gain
# Prepare shift
shift = mean * scale
# If bias is provided, use it
if bias is not None:
shift = shift - bias
return x * scale - shift
# return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
# Cast x to float32 if necessary
float_x = x.float()
# Calculate expected value of x (m) and expected value of x**2 (m2)
# Mean of x
m = torch.mean(float_x, [0, 2, 3], keepdim=True)
# Mean of x squared
m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
# Calculate variance as mean of squared minus mean squared.
var = (m2 - m ** 2)
# Cast back to float 16 if necessary
var = var.type(x.type())
m = m.type(x.type())
# Return mean and variance for updating stored mean/var if requested
if return_mean_var:
return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
else:
return fused_bn(x, m, var, gain, bias, eps)
# My batchnorm, supports standing stats
class myBN(nn.Module):
def __init__(self, num_channels, eps=1e-5, momentum=0.1):
super(myBN, self).__init__()
# momentum for updating running stats
self.momentum = momentum
# epsilon to avoid dividing by 0
self.eps = eps
# Register buffers
self.register_buffer('stored_mean', torch.zeros(num_channels))
self.register_buffer('stored_var', torch.ones(num_channels))
self.register_buffer('accumulation_counter', torch.zeros(1))
# Accumulate running means and vars
self.accumulate_standing = False
# reset standing stats
def reset_stats(self):
self.stored_mean[:] = 0
self.stored_var[:] = 0
self.accumulation_counter[:] = 0
def forward(self, x, gain, bias):
if self.training:
out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
# If accumulating standing stats, increment them
if self.accumulate_standing:
self.stored_mean[:] = self.stored_mean + mean.data
self.stored_var[:] = self.stored_var + var.data
self.accumulation_counter += 1.0
# If not accumulating standing stats, take running averages
else:
self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
return out
# If not in training mode, use the stored statistics
else:
mean = self.stored_mean.view(1, -1, 1, 1)
var = self.stored_var.view(1, -1, 1, 1)
# If using standing stats, divide them by the accumulation counter
if self.accumulate_standing:
mean = mean / self.accumulation_counter
var = var / self.accumulation_counter
return fused_bn(x, mean, var, gain, bias, self.eps)
# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
# If number of channels specified in norm_style:
if 'ch' in norm_style:
ch = int(norm_style.split('_')[-1])
groups = max(int(x.shape[1]) // ch, 1)
# If number of groups specified in norm style
elif 'grp' in norm_style:
groups = int(norm_style.split('_')[-1])
# If neither, default to groups = 16
else:
groups = 16
return F.group_norm(x, groups)
# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
cross_replica=False, mybn=False, norm_style='bn', ):
super(ccbn, self).__init__()
self.output_size, self.input_size = output_size, input_size
# Prepare gain and bias layers
self.gain = which_linear(input_size, output_size)
self.bias = which_linear(input_size, output_size)
# epsilon to avoid dividing by 0
self.eps = eps
# Momentum
self.momentum = momentum
# Use cross-replica batchnorm?
self.cross_replica = cross_replica
# Use my batchnorm?
self.mybn = mybn
# Norm style?
self.norm_style = norm_style
if self.cross_replica or self.mybn:
self.bn = myBN(output_size, self.eps, self.momentum)
elif self.norm_style in ['bn', 'in']:
self.register_buffer('stored_mean', torch.zeros(output_size))
self.register_buffer('stored_var', torch.ones(output_size))
def forward(self, x, y):
# Calculate class-conditional gains and biases
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
bias = self.bias(y).view(y.size(0), -1, 1, 1)
# If using my batchnorm
if self.mybn or self.cross_replica:
return self.bn(x, gain=gain, bias=bias)
# else:
else:
if self.norm_style == 'bn':
out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
self.training, 0.1, self.eps)
elif self.norm_style == 'in':
out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
self.training, 0.1, self.eps)
elif self.norm_style == 'gn':
out = groupnorm(x, self.norm_style)
elif self.norm_style == 'nonorm':
out = x
return out * gain + bias
def extra_repr(self):
s = 'out: {output_size}, in: {input_size},'
s += ' cross_replica={cross_replica}'
return s.format(**self.__dict__)
# Normal, non-class-conditional BN
class bn(nn.Module):
def __init__(self, output_size, eps=1e-5, momentum=0.1,
cross_replica=False, mybn=False):
super(bn, self).__init__()
self.output_size = output_size
# Prepare gain and bias layers
self.gain = P(torch.ones(output_size), requires_grad=True)
self.bias = P(torch.zeros(output_size), requires_grad=True)
# epsilon to avoid dividing by 0
self.eps = eps
# Momentum
self.momentum = momentum
# Use cross-replica batchnorm?
self.cross_replica = cross_replica
# Use my batchnorm?
self.mybn = mybn
if self.cross_replica or mybn:
self.bn = myBN(output_size, self.eps, self.momentum)
# Register buffers if neither of the above
else:
self.register_buffer('stored_mean', torch.zeros(output_size))
self.register_buffer('stored_var', torch.ones(output_size))
def forward(self, x, y=None):
if self.cross_replica or self.mybn:
gain = self.gain.view(1, -1, 1, 1)
bias = self.bias.view(1, -1, 1, 1)
return self.bn(x, gain=gain, bias=bias)
else:
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
self.bias, self.training, self.momentum, self.eps)
# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
def __init__(self, in_channels, out_channels,
which_conv=nn.Conv2d, which_bn=bn, activation=None,
upsample=None):
super(GBlock, self).__init__()
self.in_channels, self.out_channels = in_channels, out_channels
self.which_conv, self.which_bn = which_conv, which_bn
self.activation = activation
self.upsample = upsample
# Conv layers
self.conv1 = self.which_conv(self.in_channels, self.out_channels)
self.conv2 = self.which_conv(self.out_channels, self.out_channels)
self.learnable_sc = in_channels != out_channels or upsample
if self.learnable_sc:
self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
# Batchnorm layers
self.bn1 = self.which_bn(in_channels)
self.bn2 = self.which_bn(out_channels)
# upsample layers
self.upsample = upsample
def forward(self, x, y):
h = self.activation(self.bn1(x, y))
if self.upsample:
h = self.upsample(h)
x = self.upsample(x)
h = self.conv1(h)
h = self.activation(self.bn2(h, y))
h = self.conv2(h)
if self.learnable_sc:
x = self.conv_sc(x)
return h + x
# Residual block for the discriminator
class DBlock(nn.Module):
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
preactivation=False, activation=None, downsample=None, ):
super(DBlock, self).__init__()
self.in_channels, self.out_channels = in_channels, out_channels
# If using wide D (as in SA-GAN and BigGAN), change the channel pattern
self.hidden_channels = self.out_channels if wide else self.in_channels
self.which_conv = which_conv
self.preactivation = preactivation
self.activation = activation
self.downsample = downsample
# Conv layers
self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
self.learnable_sc = True if (in_channels != out_channels) or downsample else False
if self.learnable_sc:
self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
def shortcut(self, x):
if self.preactivation:
if self.learnable_sc:
x = self.conv_sc(x)
if self.downsample:
x = self.downsample(x)
else:
if self.downsample:
x = self.downsample(x)
if self.learnable_sc:
x = self.conv_sc(x)
return x
def forward(self, x):
if self.preactivation:
# h = self.activation(x) # NOT TODAY SATAN
# Andy's note: This line *must* be an out-of-place ReLU or it
# will negatively affect the shortcut connection.
h = F.relu(x)
else:
h = x
h = self.conv1(h)
h = self.conv2(self.activation(h))
if self.downsample:
h = self.downsample(h)
return h + self.shortcut(x)
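# --- Hedged usage sketch (editor's addition, not part of the commit) ---
# SNConv2d acts like a plain Conv2d whose weight is divided by a running
# power-iteration estimate of its top singular value on every forward pass
# while in training mode; the estimate is logged in the sv buffers.
def _demo_snconv():
    conv = SNConv2d(3, 16, kernel_size=3, padding=1, num_svs=1, num_itrs=1)
    y = conv(torch.randn(2, 3, 32, 32))
    print(y.shape, conv.sv[0].item())  # expected: (2, 16, 32, 32) plus the sv estimate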

View File

@@ -0,0 +1,375 @@
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
class BlurLayer(nn.Module):
def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
super(BlurLayer, self).__init__()
if kernel is None:
kernel = [1, 2, 1]
kernel = torch.tensor(kernel, dtype=torch.float32)
kernel = kernel[:, None] * kernel[None, :]
kernel = kernel[None, None]
if normalize:
kernel = kernel / kernel.sum()
if flip:
kernel = torch.flip(kernel, dims=[2, 3])  # negative strides are unsupported on torch tensors
self.register_buffer('kernel', kernel)
self.stride = stride
def forward(self, x):
# expand kernel channels
kernel = self.kernel.expand(x.size(1), -1, -1, -1)
x = F.conv2d(
x,
kernel,
stride=self.stride,
padding=int((self.kernel.size(2) - 1) / 2),
groups=x.size(1)
)
return x
class Upscale2d(nn.Module):
@staticmethod
def upscale2d(x, factor=2, gain=1):
assert x.dim() == 4
if gain != 1:
x = x * gain
if factor != 1:
shape = x.shape
x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
return x
def __init__(self, factor=2, gain=1):
super().__init__()
assert isinstance(factor, int) and factor >= 1
self.gain = gain
self.factor = factor
def forward(self, x):
return self.upscale2d(x, factor=self.factor, gain=self.gain)
class Downscale2d(nn.Module):
def __init__(self, factor=2, gain=1):
super().__init__()
assert isinstance(factor, int) and factor >= 1
self.factor = factor
self.gain = gain
if factor == 2:
f = [np.sqrt(gain) / factor] * factor
self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)
else:
self.blur = None
def forward(self, x):
assert x.dim() == 4
# 2x2, float32 => downscale using _blur2d().
if self.blur is not None and x.dtype == torch.float32:
return self.blur(x)
# Apply gain.
if self.gain != 1:
x = x * self.gain
# No-op => early exit.
if self.factor == 1:
return x
# Large factor => downscale using tf.nn.avg_pool().
# NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
return F.avg_pool2d(x, self.factor)
class EqualizedConv2d(nn.Module):
"""Conv layer with equalized learning rate and custom learning rate multiplier."""
def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** 0.5, use_wscale=False,
lrmul=1, bias=True, intermediate=None, upscale=False, downscale=False):
super().__init__()
if upscale:
self.upscale = Upscale2d()
else:
self.upscale = None
if downscale:
self.downscale = Downscale2d()
else:
self.downscale = None
he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
self.kernel_size = kernel_size
if use_wscale:
init_std = 1.0 / lrmul
self.w_mul = he_std * lrmul
else:
init_std = he_std / lrmul
self.w_mul = lrmul
self.weight = torch.nn.Parameter(
torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
if bias:
self.bias = torch.nn.Parameter(torch.zeros(output_channels))
self.b_mul = lrmul
else:
self.bias = None
self.intermediate = intermediate
def forward(self, x):
bias = self.bias
if bias is not None:
bias = bias * self.b_mul
have_convolution = False
if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
# this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
# this really needs to be cleaned up and go into the conv...
w = self.weight * self.w_mul
w = w.permute(1, 0, 2, 3)
# probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
w = F.pad(w, [1, 1, 1, 1])
w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
have_convolution = True
elif self.upscale is not None:
x = self.upscale(x)
downscale = self.downscale
intermediate = self.intermediate
if downscale is not None and min(x.shape[2:]) >= 128:
w = self.weight * self.w_mul
w = F.pad(w, [1, 1, 1, 1])
# in contrast to upscale, this is a mean...
w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25 # avg_pool?
x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
have_convolution = True
downscale = None
elif downscale is not None:
assert intermediate is None
intermediate = downscale
if not have_convolution and intermediate is None:
return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
elif not have_convolution:
x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)
if intermediate is not None:
x = intermediate(x)
if bias is not None:
x = x + bias.view(1, -1, 1, 1)
return x
class EqualizedLinear(nn.Module):
"""Linear layer with equalized learning rate and custom learning rate multiplier."""
def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True):
super().__init__()
he_std = gain * input_size ** (-0.5) # He init
# Equalized learning rate and custom learning rate multiplier.
if use_wscale:
init_std = 1.0 / lrmul
self.w_mul = he_std * lrmul
else:
init_std = he_std / lrmul
self.w_mul = lrmul
self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
if bias:
self.bias = torch.nn.Parameter(torch.zeros(output_size))
self.b_mul = lrmul
else:
self.bias = None
def forward(self, x):
bias = self.bias
if bias is not None:
bias = bias * self.b_mul
return F.linear(x, self.weight * self.w_mul, bias)
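# --- Hedged check (editor's addition, not part of the commit) ---
# With use_wscale=True the weight is drawn at unit std and scaled by the He
# constant at runtime, so the effective init matches He initialization while
# the optimizer sees the rescaled parameterization.
def _demo_equalized_linear():
    lin = EqualizedLinear(512, 256, use_wscale=True)
    print(lin(torch.randn(2, 512)).shape)  # expected: (2, 256)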
class View(nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.view(x.size(0), *self.shape)
class StddevLayer(nn.Module):
def __init__(self, group_size=4, num_new_features=1):
super().__init__()
self.group_size = group_size
self.num_new_features = num_new_features
def forward(self, x):
b, c, h, w = x.shape
group_size = min(self.group_size, b)
y = x.reshape([group_size, -1, self.num_new_features,
c // self.num_new_features, h, w])
y = y - y.mean(0, keepdim=True)
y = (y ** 2).mean(0, keepdim=True)
y = (y + 1e-8) ** 0.5
y = y.mean([3, 4, 5], keepdim=True).squeeze(3) # don't keep the meaned-out channels
y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w)
z = torch.cat([x, y], dim=1)
return z
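# --- Hedged shape check (editor's addition, not part of the commit) ---
# StddevLayer appends num_new_features channels of cross-sample statistics,
# so the channel count grows from c to c + num_new_features; the batch size
# must be divisible by group_size.
def _demo_stddev_layer():
    layer = StddevLayer(group_size=4, num_new_features=1)
    print(layer(torch.randn(8, 16, 4, 4)).shape)  # expected: (8, 17, 4, 4)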
class DiscriminatorBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel):
super().__init__(OrderedDict([
('conv0', EqualizedConv2d(in_channels, in_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)),
# out channels nf(res-1)
('act0', activation_layer),
('blur', BlurLayer(kernel=blur_kernel)),
('conv1_down', EqualizedConv2d(in_channels, out_channels, kernel_size=3,
gain=gain, use_wscale=use_wscale, downscale=True)),
('act1', activation_layer)]))
class DiscriminatorTop(nn.Sequential):
def __init__(self,
mbstd_group_size,
mbstd_num_features,
in_channels,
intermediate_channels,
gain, use_wscale,
activation_layer,
resolution=4,
in_channels2=None,
output_features=1,
last_gain=1):
"""
:param mbstd_group_size:
:param mbstd_num_features:
:param in_channels:
:param intermediate_channels:
:param gain:
:param use_wscale:
:param activation_layer:
:param resolution:
:param in_channels2:
:param output_features:
:param last_gain:
"""
layers = []
if mbstd_group_size > 1:
layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))
if in_channels2 is None:
in_channels2 = in_channels
layers.append(('conv', EqualizedConv2d(in_channels + mbstd_num_features, in_channels2, kernel_size=3,
gain=gain, use_wscale=use_wscale)))
layers.append(('act0', activation_layer))
layers.append(('view', View(-1)))
layers.append(('dense0', EqualizedLinear(in_channels2 * resolution * resolution, intermediate_channels,
gain=gain, use_wscale=use_wscale)))
layers.append(('act1', activation_layer))
layers.append(('dense1', EqualizedLinear(intermediate_channels, output_features,
gain=last_gain, use_wscale=use_wscale)))
super().__init__(OrderedDict(layers))
class StyleGanDiscriminator(nn.Module):
def __init__(self, resolution, num_channels=3, fmap_base=8192, fmap_decay=1.0, fmap_max=512,
nonlinearity='lrelu', use_wscale=True, mbstd_group_size=4, mbstd_num_features=1,
blur_filter=None, structure='fixed', **kwargs):
"""
Discriminator used in the StyleGAN paper.
:param num_channels: Number of input color channels. Overridden based on dataset.
:param resolution: Input resolution. Overridden based on dataset.
# label_size=0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
:param fmap_base: Overall multiplier for the number of feature maps.
:param fmap_decay: log2 feature map reduction when doubling the resolution.
:param fmap_max: Maximum number of feature maps in any layer.
:param nonlinearity: Activation function: 'relu', 'lrelu'
:param use_wscale: Enable equalized learning rate?
:param mbstd_group_size: Group size for the mini_batch standard deviation layer, 0 = disable.
:param mbstd_num_features: Number of features for the mini_batch standard deviation layer.
:param blur_filter: Low-pass filter to apply when resampling activations. None = no filtering.
:param structure: 'fixed' = no progressive growing, 'linear' = progressive growing with fade-in.
:param kwargs: Ignore unrecognized keyword args.
"""
super(StyleGanDiscriminator, self).__init__()
def nf(stage):
return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
self.mbstd_num_features = mbstd_num_features
self.mbstd_group_size = mbstd_group_size
self.structure = structure
# if blur_filter is None:
# blur_filter = [1, 2, 1]
resolution_log2 = int(np.log2(resolution))
assert resolution == 2 ** resolution_log2 and resolution >= 4
self.depth = resolution_log2 - 1
act, gain = {'relu': (torch.relu, np.sqrt(2)),
'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
# create the remaining layers
blocks = []
from_rgb = []
for res in range(resolution_log2, 2, -1):
# name = '{s}x{s}'.format(s=2 ** res)
blocks.append(DiscriminatorBlock(nf(res - 1), nf(res - 2),
gain=gain, use_wscale=use_wscale, activation_layer=act,
blur_kernel=blur_filter))
# create the fromRGB layers for various inputs:
from_rgb.append(EqualizedConv2d(num_channels, nf(res - 1), kernel_size=1,
gain=gain, use_wscale=use_wscale))
self.blocks = nn.ModuleList(blocks)
# Building the final block.
self.final_block = DiscriminatorTop(self.mbstd_group_size, self.mbstd_num_features,
in_channels=nf(2), intermediate_channels=nf(2),
gain=gain, use_wscale=use_wscale, activation_layer=act)
from_rgb.append(EqualizedConv2d(num_channels, nf(2), kernel_size=1,
gain=gain, use_wscale=use_wscale))
self.from_rgb = nn.ModuleList(from_rgb)
# register the temporary downSampler
self.temporaryDownsampler = nn.AvgPool2d(2)
def forward(self, images_in, depth=0, alpha=1.):
"""
:param images_in: First input: Images [mini_batch, channel, height, width].
:param depth: current height of operation (Progressive GAN)
:param alpha: current value of alpha for fade-in
:return:
"""
if self.structure == 'fixed':
x = self.from_rgb[0](images_in)
for i, block in enumerate(self.blocks):
x = block(x)
scores_out = self.final_block(x)
elif self.structure == 'linear':
assert depth < self.depth, "Requested output depth cannot be produced"
if depth > 0:
residual = self.from_rgb[self.depth - depth](self.temporaryDownsampler(images_in))
straight = self.blocks[self.depth - depth - 1](self.from_rgb[self.depth - depth - 1](images_in))
x = (alpha * straight) + ((1 - alpha) * residual)
for block in self.blocks[(self.depth - depth):]:
x = block(x)
else:
x = self.from_rgb[-1](images_in)
scores_out = self.final_block(x)
else:
raise KeyError("Unknown structure: ", self.structure)
return scores_out
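# --- Hedged usage sketch (editor's addition, not part of the commit) ---
# With structure='fixed' the discriminator consumes full-resolution images
# directly; blur_filter is left at None here, which BlurLayer replaces with
# its default [1, 2, 1] kernel.
def _demo_stylegan_d():
    D = StyleGanDiscriminator(resolution=128)
    print(D(torch.randn(2, 3, 128, 128)).shape)  # expected: (2, 1)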

View File

View File

@@ -7,8 +7,7 @@ import torch
import torchvision
from munch import munchify
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.SPSR_arch as spsr
import models.archs.SRResNet_arch as SRResNet_arch
@@ -17,9 +16,11 @@ import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.feature_arch as feature_arch
import models.archs.panet.panet as panet
import models.archs.rcan as rcan
import models.archs.ChainedEmbeddingGen as chained
from models.archs import srg2_classic
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.archs.pyramid_arch import BasicResamplingFlowNet
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
from models.archs.teco_resgen import TecoGen
@@ -90,15 +91,6 @@ def define_G(opt, net_key='network_G', scale=None):
netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
elif which_model == 'chained_gen_structured':
rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
netG = chained.ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, in_nc=in_nc)
elif which_model == 'multifaceted_chained':
scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2
netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
elif which_model == "flownet2":
from models.flownet2.models import FlowNet2
ld = 'load_path' in opt_net.keys()
@@ -125,12 +117,19 @@ def define_G(opt, net_key='network_G', scale=None):
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
scale=opt_net['scale'],
bottom_latent_only=opt_net['bottom_latent_only'])
elif which_model == "adarrdb":
netG = AdaRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
scale=opt_net['scale'])
elif which_model == "latent_estimator":
if opt_net['version'] == 2:
netG = LatentEstimator2(in_nc=3, nf=opt_net['nf'])
else:
overwrite = [1,2] if opt_net['only_base_level'] else []
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
elif which_model == "linear_latent_estimator":
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
@@ -159,19 +158,19 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
netD = GradDiscWrapper(netD)
elif which_model == 'discriminator_vgg_128_gn_checkpointed':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True)
elif which_model == 'stylegan_vgg':
netD = StyleGanDiscriminator(128)
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_50':
netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_passthrough':
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
number_skips=opt_net['number_skips'], use_bn=True,
disable_passthrough=opt_net['disable_passthrough'])
elif which_model == 'resnext':
netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
netD.load_state_dict(state_dict, strict=False)
#state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
#netD.load_state_dict(state_dict, strict=False)
netD.fc = torch.nn.Linear(512 * 4, 1)
elif which_model == 'biggan_resnet':
netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2))
elif which_model == 'discriminator_pix':
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet":

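# --- Hedged usage sketch (editor's addition, not part of the commit) ---
# Selecting the new discriminators through the option-dict dispatch above.
# The 'which_model_D' key name is an assumption based on this codebase's
# conventions; 'stylegan_vgg' ignores nf/in_nc and always builds a 128px
# StyleGAN discriminator, while 'biggan_resnet' uses the BigGAN default arch.
def _demo_define_d():
    opt_net = {'which_model_D': 'biggan_resnet', 'in_nc': 3, 'nf': 64}
    return define_D_net(opt_net, img_sz=128)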
View File

@@ -265,7 +265,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -280,7 +280,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)