diff --git a/codes/models/__init__.py b/codes/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py deleted file mode 100644 index f283e1b3..00000000 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ /dev/null @@ -1,197 +0,0 @@ -import os - -import torch -import torchvision -from torch import nn - -from models.archs.SPSR_arch import ImageGradientNoPadding -from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \ - FinalUpsampleBlock2x, ReferenceJoinBlock -from models.archs.spinenet_arch import SpineNet -from utils.util import checkpoint - - -class BasicEmbeddingPyramid(nn.Module): - def __init__(self, use_norms=True): - super(BasicEmbeddingPyramid, self).__init__() - self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False) - self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False), - ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms), - ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False), - ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)]) - self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu), - ExpansionBlock2(128, 64, block=ConvGnLelu)]) - self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False) - self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms) - self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False) - self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms) - - self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False, - weight_init_factor=.1), - ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False, - weight_init_factor=.1), - ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False, - weight_init_factor=.1), - ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False, - weight_init_factor=.1)) - - def forward(self, x, *embeddings): - p = self.initial_process(x) - identities = [] - for i in range(2): - identities.append(p) - p = self.reducers[i*2](p) - p = self.reducers[i*2+1](p) - if i == 0: - p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0])) - elif i == 1: - p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1])) - for i in range(2): - p = self.expanders[i](p, identities[-(i+1)]) - x = self.final_process(torch.cat([x, p], dim=1)) - return x, p - - - - -class ChainedEmbeddingGenWithStructure(nn.Module): - def __init__(self, in_nc=3, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2): - super(ChainedEmbeddingGenWithStructure, self).__init__() - self.recurrent = recurrent - self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False) - if recurrent: - self.recurrent_nf = recurrent_nf - self.recurrent_stride = recurrent_stride - self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False) - self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) - 
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) - self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) - self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)]) - self.structure_upsample = FinalUpsampleBlock2x(64) - self.grad_extract = ImageGradientNoPadding() - self.upsample = FinalUpsampleBlock2x(64) - self.ref_join_std = 0 - - def forward(self, x, recurrent=None): - fea = self.initial_conv(x) - if self.recurrent: - if recurrent is None: - if self.recurrent_nf == 3: - recurrent = torch.zeros_like(x) - if self.recurrent_stride != 1: - recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest') - else: - recurrent = torch.zeros_like(fea) - rec = self.recurrent_process(recurrent) - fea, recstd = self.recurrent_join(fea, rec) - self.ref_join_std = recstd.item() - if self.spine is not None: - emb = checkpoint(self.spine, fea) - else: - b,f,h,w = fea.shape - emb = (torch.zeros((b,f,h//2,w//2), device=fea.device), - torch.zeros((b,f,h//4,w//4), device=fea.device)) - grad = fea - for i, block in enumerate(self.blocks): - fea = fea + checkpoint(block, fea, *emb)[0] - if i < 3: - structure_br = checkpoint(self.structure_joins[i], grad, fea) - grad = grad + checkpoint(self.structure_blocks[i], structure_br) - out = checkpoint(self.upsample, fea) - return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea - - def get_debug_values(self, step, net_name): - return { 'ref_join_std': self.ref_join_std } - - -# This is a structural block that learns to mute regions of a residual transformation given a signal. -class OptionalPassthroughBlock(nn.Module): - def __init__(self, nf, initial_bias=10): - super(OptionalPassthroughBlock, self).__init__() - self.switch_process = nn.Sequential(ConvGnLelu(nf, nf // 2, 1, activation=False, norm=False, bias=False), - ConvGnLelu(nf // 2, nf // 4, 1, activation=False, norm=False, bias=False), - ConvGnLelu(nf // 4, 1, 1, activation=False, norm=False, bias=False)) - self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float), requires_grad=True) - self.activation = nn.Sigmoid() - - def forward(self, x, switch_signal): - switch = self.switch_process(switch_signal) - bypass_map = self.activation(self.bias + switch) - return x * bypass_map, bypass_map - - -class MultifacetedChainedEmbeddingGen(nn.Module): - def __init__(self, depth=10, scale=2): - super(MultifacetedChainedEmbeddingGen, self).__init__() - assert scale == 2 or scale == 4 - - self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) - - if scale == 2: - self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False) - else: - self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=7, stride=4, norm=False, bias=True, activation=False) - self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - - self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False) - self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - - self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) - self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in 
range(depth)]) - self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)]) - self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) - self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)]) - self.structure_upsample = FinalUpsampleBlock2x(64, scale=scale) - self.grad_extract = ImageGradientNoPadding() - self.upsample = FinalUpsampleBlock2x(64, scale=scale) - - self.teco_ref_std = 0 - self.prog_ref_std = 0 - self.block_residual_means = [0 for _ in range(depth)] - self.block_residual_stds = [0 for _ in range(depth)] - self.bypass_maps = [] - - def forward(self, x, teco_recurrent=None, prog_recurrent=None): - fea = self.initial_conv(x) - - # Integrate recurrence inputs. - if teco_recurrent is not None: - teco_rec = self.teco_recurrent_process(teco_recurrent) - fea, std = self.teco_recurrent_join(fea, teco_rec) - self.teco_ref_std = std.item() - elif prog_recurrent is not None: - prog_rec = self.prog_recurrent_process(prog_recurrent) - prog_rec, std = self.prog_recurrent_join(fea, prog_rec) - self.prog_ref_std = std.item() - - emb = checkpoint(self.spine, fea) - grad = fea - self.bypass_maps = [] - for i, block in enumerate(self.blocks): - residual, context = checkpoint(block, fea, *emb) - residual, bypass_map = checkpoint(self.bypasses[i], residual, context) - fea = fea + residual - self.bypass_maps.append(bypass_map.detach()) - self.block_residual_means[i] = residual.mean().item() - self.block_residual_stds[i] = residual.std().item() - if i < 3: - structure_br = checkpoint(self.structure_joins[i], grad, fea) - grad = grad + checkpoint(self.structure_blocks[i], structure_br) - out = checkpoint(self.upsample, fea) - return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea - - def visual_dbg(self, step, path): - for i, bm in enumerate(self.bypass_maps): - torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) - - def get_debug_values(self, step, net_name): - biases = [b.bias.item() for b in self.bypasses] - blk_stds, blk_means = {}, {} - for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)): - blk_stds['block_%i' % (i+1,)] = s - blk_means['block_%i' % (i+1,)] = m - return {'teco_std': self.teco_ref_std, - 'prog_std': self.prog_ref_std, - 'bypass_biases': sum(biases) / len(biases), - 'blocks_std': blk_stds, 'blocks_mean': blk_means} diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py deleted file mode 100644 index 236b1647..00000000 --- a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py +++ /dev/null @@ -1,225 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np - - -__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - -def conv5x5(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, - padding=2, bias=False) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -class 
FixupBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, use_bn=False, conv_create=conv3x3): - super(FixupBasicBlock, self).__init__() - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.bias1a = nn.Parameter(torch.zeros(1)) - self.conv1 = conv_create(inplanes, planes, stride) - self.bias1b = nn.Parameter(torch.zeros(1)) - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.bias2a = nn.Parameter(torch.zeros(1)) - self.conv2 = conv_create(planes, planes) - self.scale = nn.Parameter(torch.ones(1)) - self.bias2b = nn.Parameter(torch.zeros(1)) - self.downsample = downsample - self.stride = stride - self.use_bn = use_bn - if use_bn: - self.bn1 = nn.BatchNorm2d(planes) - self.bn2 = nn.BatchNorm2d(planes) - - def forward(self, x): - identity = x - - out = self.conv1(x + self.bias1a) - if self.use_bn: - out = self.bn1(out) - out = self.lrelu(out + self.bias1b) - - out = self.conv2(out + self.bias2a) - if self.use_bn: - out = self.bn2(out) - out = out * self.scale + self.bias2b - - if self.downsample is not None: - identity = self.downsample(x + self.bias1a) - - out += identity - out = self.lrelu(out) - - return out - -class FixupBottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(FixupBottleneck, self).__init__() - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.bias1a = nn.Parameter(torch.zeros(1)) - self.conv1 = conv1x1(inplanes, planes) - self.bias1b = nn.Parameter(torch.zeros(1)) - self.bias2a = nn.Parameter(torch.zeros(1)) - self.conv2 = conv3x3(planes, planes, stride) - self.bias2b = nn.Parameter(torch.zeros(1)) - self.bias3a = nn.Parameter(torch.zeros(1)) - self.conv3 = conv1x1(planes, planes * self.expansion) - self.scale = nn.Parameter(torch.ones(1)) - self.bias3b = nn.Parameter(torch.zeros(1)) - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x + self.bias1a) - out = self.lrelu(out + self.bias1b) - - out = self.conv2(out + self.bias2a) - out = self.lrelu(out + self.bias2b) - - out = self.conv3(out + self.bias3a) - out = out * self.scale + self.bias3b - - if self.downsample is not None: - identity = self.downsample(x + self.bias1a) - - out += identity - out = self.lrelu(out) - - return out - - -class FixupResNet(nn.Module): - - def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False, - disable_passthrough=False): - super(FixupResNet, self).__init__() - self.num_layers = sum(layers) - self.inplanes = 3 - self.number_skips = number_skips - self.disable_passthrough = disable_passthrough - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5) - if number_skips > 0: - self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator. - self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5) - if number_skips > 1: - self.inplanes = self.inplanes + 3 # Accomodate a second skip connection from the generator. - self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn) - # SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features. 
- # Therefore, level off the filter count from this block forwards. - self.layer3 = self._make_layer(block, num_filters*8, layers[3], stride=2, use_bn=use_bn) - self.layer4 = self._make_layer(block, num_filters*8, layers[4], stride=2, use_bn=use_bn) - self.bias2 = nn.Parameter(torch.zeros(1)) - reduced_img_sz = int(input_img_size / 32) - self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) - self.fc2 = nn.Linear(100, num_classes) - - for m in self.modules(): - if isinstance(m, FixupBasicBlock): - nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) - nn.init.constant_(m.conv2.weight, 0) - if m.downsample is not None: - nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) - elif isinstance(m, FixupBottleneck): - nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25)) - nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25)) - nn.init.constant_(m.conv3.weight, 0) - if m.downsample is not None: - nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) - - def _make_layer(self, block, outplanes, blocks, stride=1, use_bn=False, conv_type=conv3x3): - layers = [] - for _ in range(1, blocks): - layers.append(block(self.inplanes, self.inplanes)) - - downsample = None - if stride != 1 or self.inplanes != outplanes * block.expansion: - downsample = conv1x1(self.inplanes, outplanes * block.expansion, stride) - layers.append(block(self.inplanes, outplanes, stride, downsample, use_bn=use_bn, conv_create=conv_type)) - self.inplanes = outplanes * block.expansion - - return nn.Sequential(*layers) - - def forward(self, x): - if len(x) == 3: - # This class can take a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input. - x, med_skip, lo_skip = x - else: - # Or just a tuple with only the high res input (this assumes number_skips was set right). - x = x[0] - - if self.disable_passthrough: - if self.number_skips > 0: - med_skip = torch.zeros_like(med_skip) - if self.number_skips > 1: - lo_skip = torch.zeros_like(lo_skip) - x = self.layer0(x) - if self.number_skips > 0: - x = torch.cat([x, med_skip], dim=1) - x = self.layer1(x) - if self.number_skips > 1: - x = torch.cat([x, lo_skip], dim=1) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = x.view(x.size(0), -1) - x = self.lrelu(self.fc1(x)) - x = self.fc2(x + self.bias2) - - return x - - -def fixup_resnet18(**kwargs): - """Constructs a Fixup-ResNet-18 model.2 - """ - model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2, 2], **kwargs) - return model - - -def fixup_resnet34(**kwargs): - """Constructs a Fixup-ResNet-34 model. - """ - model = FixupResNet(FixupBasicBlock, [5, 5, 3, 3, 3], **kwargs) - return model - - -def fixup_resnet50(**kwargs): - """Constructs a Fixup-ResNet-50 model. - """ - model = FixupResNet(FixupBottleneck, [3, 4, 6, 3, 2], **kwargs) - return model - - -def fixup_resnet101(**kwargs): - """Constructs a Fixup-ResNet-101 model. - """ - model = FixupResNet(FixupBottleneck, [3, 4, 23, 3, 2], **kwargs) - return model - - -def fixup_resnet152(**kwargs): - """Constructs a Fixup-ResNet-152 model. 
- """ - model = FixupResNet(FixupBottleneck, [3, 8, 36, 3, 2], **kwargs) - return model - - -__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] \ No newline at end of file diff --git a/codes/models/archs/biggan/biggan_discriminator.py b/codes/models/archs/biggan/biggan_discriminator.py new file mode 100644 index 00000000..a85f4443 --- /dev/null +++ b/codes/models/archs/biggan/biggan_discriminator.py @@ -0,0 +1,139 @@ +import functools + +import torch +from torch.nn import init + +import models.archs.biggan.biggan_layers as layers +import torch.nn as nn + + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], + 'downsample' : [True] * 6 + [False], + 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample' : [True] * 5 + [False], + 'resolution' : [64, 32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]], + 'downsample' : [True] * 4 + [False], + 'resolution' : [32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,7)}} + arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]], + 'out_channels' : [item * ch for item in [4, 4, 4, 4]], + 'downsample' : [True, True, False, False], + 'resolution' : [16, 16, 16, 16], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,6)}} + return arch + + +class BigGanDiscriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, resolution=128, + D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + SN_eps=1e-12, output_dim=1, D_fp16=False, + D_init='ortho', skip_init=False, D_param='SN'): + super(BigGanDiscriminator, self).__init__() + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? 
+ self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention)[resolution] + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + # Prepare model + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=(index > 0), + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + + # Initialize weights + if not skip_init: + self.init_weights() + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for D''s initialized parameters: %d' % self.param_count) + + def forward(self, x, y=None): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial class-unconditional output + out = self.linear(h) + return out diff --git a/codes/models/archs/biggan/biggan_layers.py b/codes/models/archs/biggan/biggan_layers.py new file mode 100644 index 00000000..292d167f --- /dev/null +++ b/codes/models/archs/biggan/biggan_layers.py @@ -0,0 +1,457 @@ +''' Layers + This file contains various layers for the BigGAN models. 
+''' +import numpy as np +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + + +# Projection of x onto y +def proj(x, y): + return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) + + +# Orthogonalize x wrt list of vectors ys +def gram_schmidt(x, ys): + for y in ys: + x = x - proj(x, y) + return x + + +# Apply num_itrs steps of the power method to estimate top N singular values. +def power_iteration(W, u_, update=True, eps=1e-12): + # Lists holding singular vectors and values + us, vs, svs = [], [], [] + for i, u in enumerate(u_): + # Run one step of the power iteration + with torch.no_grad(): + v = torch.matmul(u, W) + # Run Gram-Schmidt to subtract components of all other singular vectors + v = F.normalize(gram_schmidt(v, vs), eps=eps) + # Add to the list + vs += [v] + # Update the other singular vector + u = torch.matmul(v, W.t()) + # Run Gram-Schmidt to subtract components of all other singular vectors + u = F.normalize(gram_schmidt(u, us), eps=eps) + # Add to the list + us += [u] + if update: + u_[i][:] = u + # Compute this singular value and add it to the list + svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] + # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] + return svs, us, vs + + +# Convenience passthrough function +class identity(nn.Module): + def forward(self, input): + return input + + +# Spectral normalization base class +class SN(object): + def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): + # Number of power iterations per step + self.num_itrs = num_itrs + # Number of singular values + self.num_svs = num_svs + # Transposed? + self.transpose = transpose + # Epsilon value for avoiding divide-by-0 + self.eps = eps + # Register a singular vector for each sv + for i in range(self.num_svs): + self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) + self.register_buffer('sv%d' % i, torch.ones(1)) + + # Singular vectors (u side) + @property + def u(self): + return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] + + # Singular values; + # note that these buffers are just for logging and are not used in training. + @property + def sv(self): + return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] + + # Compute the spectrally-normalized weight + def W_(self): + W_mat = self.weight.view(self.weight.size(0), -1) + if self.transpose: + W_mat = W_mat.t() + # Apply num_itrs power iterations + for _ in range(self.num_itrs): + svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) + # Update the svs + if self.training: + with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! 
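+                # Copy the freshly estimated singular values into the sv buffers;
+                # per the note above, these buffers are tracked for logging only.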
+ for i, sv in enumerate(svs): + self.sv[i][:] = sv + return self.weight / svs[0] + + +# 2D Conv layer with spectral norm +class SNConv2d(nn.Conv2d, SN): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) + + def forward(self, x): + return F.conv2d(x, self.W_(), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# Linear layer with spectral norm +class SNLinear(nn.Linear, SN): + def __init__(self, in_features, out_features, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Linear.__init__(self, in_features, out_features, bias) + SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) + + def forward(self, x): + return F.linear(x, self.W_(), self.bias) + + +# Embedding layer with spectral norm +# We use num_embeddings as the dim instead of embedding_dim here +# for convenience sake +class SNEmbedding(nn.Embedding, SN): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, + max_norm=None, norm_type=2, scale_grad_by_freq=False, + sparse=False, _weight=None, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, + max_norm, norm_type, scale_grad_by_freq, + sparse, _weight) + SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) + + def forward(self, x): + return F.embedding(x, self.W_()) + + +# A non-local block as used in SA-GAN +# Note that the implementation as described in the paper is largely incorrect; +# refer to the released code for the actual implementation. +class Attention(nn.Module): + def __init__(self, ch, which_conv=SNConv2d, name='attention'): + super(Attention, self).__init__() + # Channel multiplier + self.ch = ch + self.which_conv = which_conv + self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) + self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + + def forward(self, x, y=None): + # Apply convs + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2, 2]) + g = F.max_pool2d(self.g(x), [2, 2]) + # Perform reshapes + theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3]) + phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4) + g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) + return self.gamma * o + x + + +# Fused batchnorm op +def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): + # Apply scale and shift--if gain and bias are provided, fuse them here + # Prepare scale + scale = torch.rsqrt(var + eps) + # If a gain is provided, use it + if gain is not None: + scale = scale * gain + # Prepare shift + shift = mean * scale + # If bias is provided, use it + if bias is not None: + shift = shift - bias + return x * scale - shift + # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused 
way. + + +# Manual BN +# Calculate means and variances using mean-of-squares minus mean-squared +def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): + # Cast x to float32 if necessary + float_x = x.float() + # Calculate expected value of x (m) and expected value of x**2 (m2) + # Mean of x + m = torch.mean(float_x, [0, 2, 3], keepdim=True) + # Mean of x squared + m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) + # Calculate variance as mean of squared minus mean squared. + var = (m2 - m ** 2) + # Cast back to float 16 if necessary + var = var.type(x.type()) + m = m.type(x.type()) + # Return mean and variance for updating stored mean/var if requested + if return_mean_var: + return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() + else: + return fused_bn(x, m, var, gain, bias, eps) + + +# My batchnorm, supports standing stats +class myBN(nn.Module): + def __init__(self, num_channels, eps=1e-5, momentum=0.1): + super(myBN, self).__init__() + # momentum for updating running stats + self.momentum = momentum + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Register buffers + self.register_buffer('stored_mean', torch.zeros(num_channels)) + self.register_buffer('stored_var', torch.ones(num_channels)) + self.register_buffer('accumulation_counter', torch.zeros(1)) + # Accumulate running means and vars + self.accumulate_standing = False + + # reset standing stats + def reset_stats(self): + self.stored_mean[:] = 0 + self.stored_var[:] = 0 + self.accumulation_counter[:] = 0 + + def forward(self, x, gain, bias): + if self.training: + out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) + # If accumulating standing stats, increment them + if self.accumulate_standing: + self.stored_mean[:] = self.stored_mean + mean.data + self.stored_var[:] = self.stored_var + var.data + self.accumulation_counter += 1.0 + # If not accumulating standing stats, take running averages + else: + self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum + self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum + return out + # If not in training mode, use the stored statistics + else: + mean = self.stored_mean.view(1, -1, 1, 1) + var = self.stored_var.view(1, -1, 1, 1) + # If using standing stats, divide them by the accumulation counter + if self.accumulate_standing: + mean = mean / self.accumulation_counter + var = var / self.accumulation_counter + return fused_bn(x, mean, var, gain, bias, self.eps) + + +# Simple function to handle groupnorm norm stylization +def groupnorm(x, norm_style): + # If number of channels specified in norm_style: + if 'ch' in norm_style: + ch = int(norm_style.split('_')[-1]) + groups = max(int(x.shape[1]) // ch, 1) + # If number of groups specified in norm style + elif 'grp' in norm_style: + groups = int(norm_style.split('_')[-1]) + # If neither, default to groups = 16 + else: + groups = 16 + return F.group_norm(x, groups) + + +# Class-conditional bn +# output size is the number of channels, input size is for the linear layers +# Andy's Note: this class feels messy but I'm not really sure how to clean it up +# Suggestions welcome! (By which I mean, refactor this and make a pull request +# if you want to make this more readable/usable). 
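+# In short: per-sample gain and bias are predicted from the conditioning vector y
+# by the two linear layers below, so the block computes
+# out = normalize(x) * (1 + gain(y)) + bias(y).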
+class ccbn(nn.Module):
+    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
+                 cross_replica=False, mybn=False, norm_style='bn', ):
+        super(ccbn, self).__init__()
+        self.output_size, self.input_size = output_size, input_size
+        # Prepare gain and bias layers
+        self.gain = which_linear(input_size, output_size)
+        self.bias = which_linear(input_size, output_size)
+        # epsilon to avoid dividing by 0
+        self.eps = eps
+        # Momentum
+        self.momentum = momentum
+        # Use cross-replica batchnorm?
+        self.cross_replica = cross_replica
+        # Use my batchnorm?
+        self.mybn = mybn
+        # Norm style?
+        self.norm_style = norm_style
+
+        if self.cross_replica or self.mybn:
+            self.bn = myBN(output_size, self.eps, self.momentum)
+        elif self.norm_style in ['bn', 'in']:
+            self.register_buffer('stored_mean', torch.zeros(output_size))
+            self.register_buffer('stored_var', torch.ones(output_size))
+
+    def forward(self, x, y):
+        # Calculate class-conditional gains and biases
+        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
+        bias = self.bias(y).view(y.size(0), -1, 1, 1)
+        # If using my batchnorm
+        if self.mybn or self.cross_replica:
+            return self.bn(x, gain=gain, bias=bias)
+        else:
+            if self.norm_style == 'bn':
+                out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
+                                   self.training, 0.1, self.eps)
+            elif self.norm_style == 'in':
+                out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
+                                      self.training, 0.1, self.eps)
+            elif self.norm_style == 'gn':
+                out = groupnorm(x, self.norm_style)
+            elif self.norm_style == 'nonorm':
+                out = x
+            return out * gain + bias
+
+    def extra_repr(self):
+        s = 'out: {output_size}, in: {input_size},'
+        s += ' cross_replica={cross_replica}'
+        return s.format(**self.__dict__)
+
+
+# Normal, non-class-conditional BN
+class bn(nn.Module):
+    def __init__(self, output_size, eps=1e-5, momentum=0.1,
+                 cross_replica=False, mybn=False):
+        super(bn, self).__init__()
+        self.output_size = output_size
+        # Prepare gain and bias layers
+        self.gain = P(torch.ones(output_size), requires_grad=True)
+        self.bias = P(torch.zeros(output_size), requires_grad=True)
+        # epsilon to avoid dividing by 0
+        self.eps = eps
+        # Momentum
+        self.momentum = momentum
+        # Use cross-replica batchnorm?
+        self.cross_replica = cross_replica
+        # Use my batchnorm?
+        self.mybn = mybn
+
+        if self.cross_replica or mybn:
+            self.bn = myBN(output_size, self.eps, self.momentum)
+        # Register buffers if neither of the above
+        else:
+            self.register_buffer('stored_mean', torch.zeros(output_size))
+            self.register_buffer('stored_var', torch.ones(output_size))
+
+    def forward(self, x, y=None):
+        if self.cross_replica or self.mybn:
+            gain = self.gain.view(1, -1, 1, 1)
+            bias = self.bias.view(1, -1, 1, 1)
+            return self.bn(x, gain=gain, bias=bias)
+        else:
+            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
+                                self.bias, self.training, self.momentum, self.eps)
+
+
+# Generator blocks
+# Note that this class assumes the kernel size and padding (and any other
+# settings) have been selected in the main generator module and passed in
+# through the which_conv arg. 
Similar rules apply with which_bn (the input +# size [which is actually the number of channels of the conditional info] must +# be preselected) +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=bn, activation=None, + upsample=None): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + self.upsample = upsample + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.out_channels) + self.conv2 = self.which_conv(self.out_channels, self.out_channels) + self.learnable_sc = in_channels != out_channels or upsample + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(in_channels) + self.bn2 = self.which_bn(out_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + h = self.activation(self.bn1(x, y)) + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + h = self.conv1(h) + h = self.activation(self.bn2(h, y)) + h = self.conv2(h) + if self.learnable_sc: + x = self.conv_sc(x) + return h + x + + +# Residual block for the discriminator +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, + preactivation=False, activation=None, downsample=None, ): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels if wide else self.in_channels + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) + self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) + self.learnable_sc = True if (in_channels != out_channels) or downsample else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + + def shortcut(self, x): + if self.preactivation: + if self.learnable_sc: + x = self.conv_sc(x) + if self.downsample: + x = self.downsample(x) + else: + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = self.conv_sc(x) + return x + + def forward(self, x): + if self.preactivation: + # h = self.activation(x) # NOT TODAY SATAN + # Andy's note: This line *must* be an out-of-place ReLU or it + # will negatively affect the shortcut connection. 
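+            # F.relu defaults to inplace=False, so x is left untouched for the
+            # self.shortcut(x) call below.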
+            h = F.relu(x)
+        else:
+            h = x
+        h = self.conv1(h)
+        h = self.conv2(self.activation(h))
+        if self.downsample:
+            h = self.downsample(h)
+
+        return h + self.shortcut(x)
+
diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/fixup_resnet/DiscriminatorResnet_arch.py
similarity index 100%
rename from codes/models/archs/DiscriminatorResnet_arch.py
rename to codes/models/archs/fixup_resnet/DiscriminatorResnet_arch.py
diff --git a/codes/models/archs/fixup_resnet/__init__.py b/codes/models/archs/fixup_resnet/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/archs/stylegan/Discriminator_StyleGAN.py b/codes/models/archs/stylegan/Discriminator_StyleGAN.py
new file mode 100644
index 00000000..cab5b278
--- /dev/null
+++ b/codes/models/archs/stylegan/Discriminator_StyleGAN.py
@@ -0,0 +1,375 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class BlurLayer(nn.Module):
+    def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
+        super(BlurLayer, self).__init__()
+        if kernel is None:
+            kernel = [1, 2, 1]
+        kernel = torch.tensor(kernel, dtype=torch.float32)
+        kernel = kernel[:, None] * kernel[None, :]
+        kernel = kernel[None, None]
+        if normalize:
+            kernel = kernel / kernel.sum()
+        if flip:
+            # Torch tensors do not support negative-step slicing, so use flip().
+            kernel = torch.flip(kernel, dims=[2, 3])
+        self.register_buffer('kernel', kernel)
+        self.stride = stride
+
+    def forward(self, x):
+        # expand kernel channels
+        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
+        x = F.conv2d(
+            x,
+            kernel,
+            stride=self.stride,
+            padding=int((self.kernel.size(2) - 1) / 2),
+            groups=x.size(1)
+        )
+        return x
+
+
+class Upscale2d(nn.Module):
+    @staticmethod
+    def upscale2d(x, factor=2, gain=1):
+        assert x.dim() == 4
+        if gain != 1:
+            x = x * gain
+        if factor != 1:
+            shape = x.shape
+            x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
+            x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
+        return x
+
+    def __init__(self, factor=2, gain=1):
+        super().__init__()
+        assert isinstance(factor, int) and factor >= 1
+        self.gain = gain
+        self.factor = factor
+
+    def forward(self, x):
+        return self.upscale2d(x, factor=self.factor, gain=self.gain)
+
+
+class Downscale2d(nn.Module):
+    def __init__(self, factor=2, gain=1):
+        super().__init__()
+        assert isinstance(factor, int) and factor >= 1
+        self.factor = factor
+        self.gain = gain
+        if factor == 2:
+            f = [np.sqrt(gain) / factor] * factor
+            self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)
+        else:
+            self.blur = None
+
+    def forward(self, x):
+        assert x.dim() == 4
+        # 2x2, float32 => downscale using the blur layer.
+        if self.blur is not None and x.dtype == torch.float32:
+            return self.blur(x)
+
+        # Apply gain.
+        if self.gain != 1:
+            x = x * self.gain
+
+        # No-op => early exit.
+        if self.factor == 1:
+            return x
+
+        # Large factor => downscale using average pooling (the TF reference
+        # implementation used tf.nn.avg_pool here).
+ return F.avg_pool2d(x, self.factor) + + +class EqualizedConv2d(nn.Module): + """Conv layer with equalized learning rate and custom learning rate multiplier.""" + + def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** 0.5, use_wscale=False, + lrmul=1, bias=True, intermediate=None, upscale=False, downscale=False): + super().__init__() + if upscale: + self.upscale = Upscale2d() + else: + self.upscale = None + if downscale: + self.downscale = Downscale2d() + else: + self.downscale = None + he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init + self.kernel_size = kernel_size + if use_wscale: + init_std = 1.0 / lrmul + self.w_mul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_mul = lrmul + self.weight = torch.nn.Parameter( + torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(output_channels)) + self.b_mul = lrmul + else: + self.bias = None + self.intermediate = intermediate + + def forward(self, x): + bias = self.bias + if bias is not None: + bias = bias * self.b_mul + + have_convolution = False + if self.upscale is not None and min(x.shape[2:]) * 2 >= 128: + # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way + # this really needs to be cleaned up and go into the conv... + w = self.weight * self.w_mul + w = w.permute(1, 0, 2, 3) + # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?! + w = F.pad(w, [1, 1, 1, 1]) + w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] + x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2) + have_convolution = True + elif self.upscale is not None: + x = self.upscale(x) + + downscale = self.downscale + intermediate = self.intermediate + if downscale is not None and min(x.shape[2:]) >= 128: + w = self.weight * self.w_mul + w = F.pad(w, [1, 1, 1, 1]) + # in contrast to upscale, this is a mean... + w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25 # avg_pool? + x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2) + have_convolution = True + downscale = None + elif downscale is not None: + assert intermediate is None + intermediate = downscale + + if not have_convolution and intermediate is None: + return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2) + elif not have_convolution: + x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2) + + if intermediate is not None: + x = intermediate(x) + + if bias is not None: + x = x + bias.view(1, -1, 1, 1) + return x + + +class EqualizedLinear(nn.Module): + """Linear layer with equalized learning rate and custom learning rate multiplier.""" + + def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True): + super().__init__() + he_std = gain * input_size ** (-0.5) # He init + # Equalized learning rate and custom learning rate multiplier. 
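+        # With use_wscale, weights are initialized with std 1/lrmul and rescaled by
+        # he_std * lrmul at runtime: the effective init matches He init while
+        # gradient magnitudes scale with lrmul.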
+ if use_wscale: + init_std = 1.0 / lrmul + self.w_mul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_mul = lrmul + self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(output_size)) + self.b_mul = lrmul + else: + self.bias = None + + def forward(self, x): + bias = self.bias + if bias is not None: + bias = bias * self.b_mul + return F.linear(x, self.weight * self.w_mul, bias) + + +class View(nn.Module): + def __init__(self, *shape): + super().__init__() + self.shape = shape + + + def forward(self, x): + return x.view(x.size(0), *self.shape) + + +class StddevLayer(nn.Module): + def __init__(self, group_size=4, num_new_features=1): + super().__init__() + self.group_size = group_size + self.num_new_features = num_new_features + + def forward(self, x): + b, c, h, w = x.shape + group_size = min(self.group_size, b) + y = x.reshape([group_size, -1, self.num_new_features, + c // self.num_new_features, h, w]) + y = y - y.mean(0, keepdim=True) + y = (y ** 2).mean(0, keepdim=True) + y = (y + 1e-8) ** 0.5 + y = y.mean([3, 4, 5], keepdim=True).squeeze(3) # don't keep the meaned-out channels + y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w) + z = torch.cat([x, y], dim=1) + return z + + +class DiscriminatorBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel): + super().__init__(OrderedDict([ + ('conv0', EqualizedConv2d(in_channels, in_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)), + # out channels nf(res-1) + ('act0', activation_layer), + ('blur', BlurLayer(kernel=blur_kernel)), + ('conv1_down', EqualizedConv2d(in_channels, out_channels, kernel_size=3, + gain=gain, use_wscale=use_wscale, downscale=True)), + ('act1', activation_layer)])) + + + +class DiscriminatorTop(nn.Sequential): + def __init__(self, + mbstd_group_size, + mbstd_num_features, + in_channels, + intermediate_channels, + gain, use_wscale, + activation_layer, + resolution=4, + in_channels2=None, + output_features=1, + last_gain=1): + """ + :param mbstd_group_size: + :param mbstd_num_features: + :param in_channels: + :param intermediate_channels: + :param gain: + :param use_wscale: + :param activation_layer: + :param resolution: + :param in_channels2: + :param output_features: + :param last_gain: + """ + + layers = [] + if mbstd_group_size > 1: + layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features))) + + if in_channels2 is None: + in_channels2 = in_channels + + layers.append(('conv', EqualizedConv2d(in_channels + mbstd_num_features, in_channels2, kernel_size=3, + gain=gain, use_wscale=use_wscale))) + layers.append(('act0', activation_layer)) + layers.append(('view', View(-1))) + layers.append(('dense0', EqualizedLinear(in_channels2 * resolution * resolution, intermediate_channels, + gain=gain, use_wscale=use_wscale))) + layers.append(('act1', activation_layer)) + layers.append(('dense1', EqualizedLinear(intermediate_channels, output_features, + gain=last_gain, use_wscale=use_wscale))) + + super().__init__(OrderedDict(layers)) + + +class StyleGanDiscriminator(nn.Module): + def __init__(self, resolution, num_channels=3, fmap_base=8192, fmap_decay=1.0, fmap_max=512, + nonlinearity='lrelu', use_wscale=True, mbstd_group_size=4, mbstd_num_features=1, + blur_filter=None, structure='fixed', **kwargs): + """ + Discriminator used in the StyleGAN paper. + :param num_channels: Number of input color channels. 
Overridden based on dataset.
+        :param resolution: Input resolution. Overridden based on dataset.
+        # label_size=0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
+        :param fmap_base: Overall multiplier for the number of feature maps.
+        :param fmap_decay: log2 feature map reduction when doubling the resolution.
+        :param fmap_max: Maximum number of feature maps in any layer.
+        :param nonlinearity: Activation function: 'relu', 'lrelu'
+        :param use_wscale: Enable equalized learning rate?
+        :param mbstd_group_size: Group size for the mini_batch standard deviation layer, 0 = disable.
+        :param mbstd_num_features: Number of features for the mini_batch standard deviation layer.
+        :param blur_filter: Low-pass filter to apply when resampling activations. None = no filtering.
+        :param structure: 'fixed' = no progressive growing, 'linear' = human-readable
+        :param kwargs: Ignore unrecognized keyword args.
+        """
+        super(StyleGanDiscriminator, self).__init__()
+
+        def nf(stage):
+            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
+
+        self.mbstd_num_features = mbstd_num_features
+        self.mbstd_group_size = mbstd_group_size
+        self.structure = structure
+        # if blur_filter is None:
+        #     blur_filter = [1, 2, 1]
+
+        resolution_log2 = int(np.log2(resolution))
+        assert resolution == 2 ** resolution_log2 and resolution >= 4
+        self.depth = resolution_log2 - 1
+
+        act, gain = {'relu': (torch.relu, np.sqrt(2)),
+                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
+
+        # create the remaining layers
+        blocks = []
+        from_rgb = []
+        for res in range(resolution_log2, 2, -1):
+            # name = '{s}x{s}'.format(s=2 ** res)
+            blocks.append(DiscriminatorBlock(nf(res - 1), nf(res - 2),
+                                             gain=gain, use_wscale=use_wscale, activation_layer=act,
+                                             blur_kernel=blur_filter))
+            # create the fromRGB layers for various inputs:
+            from_rgb.append(EqualizedConv2d(num_channels, nf(res - 1), kernel_size=1,
+                                            gain=gain, use_wscale=use_wscale))
+        self.blocks = nn.ModuleList(blocks)
+
+        # Building the final block.
+        self.final_block = DiscriminatorTop(self.mbstd_group_size, self.mbstd_num_features,
+                                            in_channels=nf(2), intermediate_channels=nf(2),
+                                            gain=gain, use_wscale=use_wscale, activation_layer=act)
+        from_rgb.append(EqualizedConv2d(num_channels, nf(2), kernel_size=1,
+                                        gain=gain, use_wscale=use_wscale))
+        self.from_rgb = nn.ModuleList(from_rgb)
+
+        # register the temporary downSampler
+        self.temporaryDownsampler = nn.AvgPool2d(2)
+
+    def forward(self, images_in, depth=0, alpha=1.):
+        """
+        :param images_in: First input: Images [mini_batch, channel, height, width].
+ :param depth: current height of operation (Progressive GAN) + :param alpha: current value of alpha for fade-in + :return: + """ + + if self.structure == 'fixed': + x = self.from_rgb[0](images_in) + for i, block in enumerate(self.blocks): + x = block(x) + scores_out = self.final_block(x) + elif self.structure == 'linear': + assert depth < self.depth, "Requested output depth cannot be produced" + if depth > 0: + residual = self.from_rgb[self.depth - depth](self.temporaryDownsampler(images_in)) + straight = self.blocks[self.depth - depth - 1](self.from_rgb[self.depth - depth - 1](images_in)) + x = (alpha * straight) + ((1 - alpha) * residual) + + for block in self.blocks[(self.depth - depth):]: + x = block(x) + else: + x = self.from_rgb[-1](images_in) + + scores_out = self.final_block(x) + else: + raise KeyError("Unknown structure: ", self.structure) + + return scores_out \ No newline at end of file diff --git a/codes/models/archs/stylegan/__init__.py b/codes/models/archs/stylegan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/networks.py b/codes/models/networks.py index b168bf96..feefbd69 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -7,8 +7,7 @@ import torch import torchvision from munch import munchify -import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch -import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough +import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch import models.archs.RRDBNet_arch as RRDBNet_arch import models.archs.SPSR_arch as spsr import models.archs.SRResNet_arch as SRResNet_arch @@ -17,9 +16,11 @@ import models.archs.discriminator_vgg_arch as SRGAN_arch import models.archs.feature_arch as feature_arch import models.archs.panet.panet as panet import models.archs.rcan as rcan -import models.archs.ChainedEmbeddingGen as chained from models.archs import srg2_classic +from models.archs.biggan.biggan_discriminator import BigGanDiscriminator +from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator from models.archs.pyramid_arch import BasicResamplingFlowNet +from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2 from models.archs.teco_resgen import TecoGen @@ -90,15 +91,6 @@ def define_G(opt, net_key='network_G', scale=None): netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent) - elif which_model == 'chained_gen_structured': - rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False - recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3 - recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2 - in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3 - netG = chained.ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, in_nc=in_nc) - elif which_model == 'multifaceted_chained': - scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2 - netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale) elif which_model == "flownet2": from models.flownet2.models import 
FlowNet2 ld = 'load_path' in opt_net.keys() @@ -125,12 +117,19 @@ def define_G(opt, net_key='network_G', scale=None): blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'], bottom_latent_only=opt_net['bottom_latent_only']) + elif which_model == "adarrdb": + netG = AdaRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], + blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], + scale=opt_net['scale']) elif which_model == "latent_estimator": if opt_net['version'] == 2: netG = LatentEstimator2(in_nc=3, nf=opt_net['nf']) else: overwrite = [1,2] if opt_net['only_base_level'] else [] netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite) + elif which_model == "linear_latent_estimator": + netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG @@ -159,19 +158,19 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = GradDiscWrapper(netD) elif which_model == 'discriminator_vgg_128_gn_checkpointed': netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True) + elif which_model == 'stylegan_vgg': + netD = StyleGanDiscriminator(128) elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_50': netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) - elif which_model == 'discriminator_resnet_passthrough': - netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, - number_skips=opt_net['number_skips'], use_bn=True, - disable_passthrough=opt_net['disable_passthrough']) elif which_model == 'resnext': netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8)) - state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True) - netD.load_state_dict(state_dict, strict=False) + #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True) + #netD.load_state_dict(state_dict, strict=False) netD.fc = torch.nn.Linear(512 * 4, 1) + elif which_model == 'biggan_resnet': + netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2)) elif which_model == 'discriminator_pix': netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet": diff --git a/codes/train.py b/codes/train.py index edae8fdb..227a273e 100644 --- a/codes/train.py +++ b/codes/train.py @@ -265,7 +265,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/train2.py b/codes/train2.py index 13151cf7..f2996e83 100644 --- a/codes/train2.py +++ 
b/codes/train2.py @@ -280,7 +280,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)
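
As a quick smoke test (not part of the patch), the two newly added discriminators can be instantiated directly with their default constructor arguments; this sketch assumes 128x128 RGB inputs:

    import torch
    from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
    from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator

    x = torch.randn(2, 3, 128, 128)
    print(BigGanDiscriminator(resolution=128)(x).shape)    # torch.Size([2, 1])
    print(StyleGanDiscriminator(resolution=128)(x).shape)  # torch.Size([2, 1])

Both return one logit per image, matching what the existing define_D_net callers expect.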