From a5188bb7ca2c081770019469890a8c9914105c7b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 29 Apr 2020 15:17:43 -0600
Subject: [PATCH] Remove fixup code from arch_util

Going into its own arch.
---
 .idea/mmsr.iml                                 |   2 +
 .../models/archs/DiscriminatorResnet_arch.py   | 242 +++++++++++++-----
 codes/models/archs/arch_util.py                |  96 +------
 3 files changed, 180 insertions(+), 160 deletions(-)

diff --git a/.idea/mmsr.iml b/.idea/mmsr.iml
index 643c0574..75a1d172 100644
--- a/.idea/mmsr.iml
+++ b/.idea/mmsr.iml
@@ -2,8 +2,10 @@
+
+

diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/DiscriminatorResnet_arch.py
index e30a3ab9..b1feea64 100644
--- a/codes/models/archs/DiscriminatorResnet_arch.py
+++ b/codes/models/archs/DiscriminatorResnet_arch.py
@@ -1,85 +1,195 @@
 import torch
 import torch.nn as nn
-import torchvision
-import models.archs.arch_util as arch_util
-import functools
-import torch.nn.functional as F
-import torch.nn.utils.spectral_norm as SpectralNorm
+import numpy as np
 
-# Class that halfs the image size (x4 complexity reduction) and doubles the filter size. Substantial resnet
-# processing is also performed.
-class ResnetDownsampleLayer(nn.Module):
-    def __init__(self, starting_channels: int, number_filters: int, filter_multiplier: int, residual_blocks_input: int, residual_blocks_skip_image: int, total_residual_blocks: int):
-        super(ResnetDownsampleLayer, self).__init__()
-        self.skip_image_reducer = SpectralNorm(nn.Conv2d(starting_channels, number_filters, 3, stride=1, padding=1, bias=True))
-        self.skip_image_res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters, total_residual_blocks=total_residual_blocks), residual_blocks_skip_image)
+__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
 
-        self.input_reducer = SpectralNorm(nn.Conv2d(number_filters, number_filters*filter_multiplier, 3, stride=2, padding=1, bias=True))
-        self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters*filter_multiplier, total_residual_blocks=total_residual_blocks), residual_blocks_input)
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class FixupBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBasicBlock, self).__init__()
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bias1b = nn.Parameter(torch.zeros(1))
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        arch_util.initialize_weights([self.input_reducer, self.skip_image_reducer], 1)
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.downsample = downsample
+        self.stride = stride
 
-    def forward(self, x, skip_image):
-        # Process the skip image first.
-        skip = self.lrelu(self.skip_image_reducer(skip_image))
-        skip = self.skip_image_res_trunk(skip)
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x + self.bias1a)
+        out = self.lrelu(out + self.bias1b)
+
+        out = self.conv2(out + self.bias2a)
+        out = out * self.scale + self.bias2b
+
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.lrelu(out)
 
-        # Concat the processed skip image onto the input and perform processing.
-        out = (x + skip) / 2
-        out = self.lrelu(self.input_reducer(out))
-        out = self.res_trunk(out)
         return out
 
-class DiscriminatorResnet(nn.Module):
-    # Discriminator that downsamples 5 times with resnet blocks at each layer. On each downsample, the filter size is
-    # increased by a factor of 2. Feeds the output of the convs into a dense for prediction at the logits. Scales the
-    # final dense based on the input image size. Intended for use with input images which are multiples of 32.
-    #
-    # This discriminator also includes provisions to pass an image at various downsample steps in directly. When this
-    # is done with a generator, it will allow much shorter gradient paths between the generator and discriminator. When
-    # no downsampled images are passed into the forward() pass, they will be automatically generated from the source
-    # image using interpolation.
-    #
-    # Uses spectral normalization rather than batch normalization.
-    def __init__(self, in_nc: int, nf: int, input_img_size: int, trunk_resblocks: int, skip_resblocks: int):
-        super(DiscriminatorResnet, self).__init__()
-        self.dimensionalize = nn.Conv2d(in_nc, nf, kernel_size=3, stride=1, padding=1, bias=True)
+class FixupBottleneck(nn.Module):
+    expansion = 4
 
-        # Trunk resblocks are the important things to get right, so use those. 5=number of downsample layers.
-        total_resblocks = trunk_resblocks * 5
-        self.downsample1 = ResnetDownsampleLayer(in_nc, nf, 2, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsample2 = ResnetDownsampleLayer(in_nc, nf*2, 2, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsample3 = ResnetDownsampleLayer(in_nc, nf*4, 2, trunk_resblocks, skip_resblocks, total_resblocks)
-        # At the bottom layers, we cap the filter multiplier. We want this particular network to focus as much on the
-        # macro-details at higher image dimensionality as it does to the feature details.
-        self.downsample4 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsample5 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsamplers = [self.downsample1, self.downsample2, self.downsample3, self.downsample4, self.downsample5]
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBottleneck, self).__init__()
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bias1b = nn.Parameter(torch.zeros(1))
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.bias3a = nn.Parameter(torch.zeros(1))
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias3b = nn.Parameter(torch.zeros(1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.downsample = downsample
+        self.stride = stride
 
-        downsampled_image_size = input_img_size / 32
-        self.linear1 = nn.Linear(int(nf * 8 * downsampled_image_size * downsampled_image_size), 100)
-        self.linear2 = nn.Linear(100, 1)
+    def forward(self, x):
+        identity = x
 
-        # activation function
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        out = self.conv1(x + self.bias1a)
+        out = self.lrelu(out + self.bias1b)
 
-        arch_util.initialize_weights([self.dimensionalize, self.linear1, self.linear2], 1)
+        out = self.conv2(out + self.bias2a)
+        out = self.lrelu(out + self.bias2b)
 
-    def forward(self, x, skip_images=None):
-        if skip_images is None:
-            # Sythesize them from x.
-            skip_images = []
-            for i in range(len(self.downsamplers)):
-                m = 2 ** i
-                skip_images.append(F.interpolate(x, scale_factor=1 / m, mode='bilinear', align_corners=False))
+        out = self.conv3(out + self.bias3a)
+        out = out * self.scale + self.bias3b
 
-        fea = self.dimensionalize(x)
-        for skip, d in zip(skip_images, self.downsamplers):
-            fea = d(fea, skip)
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.lrelu(out)
 
-        fea = fea.view(fea.size(0), -1)
-        fea = self.lrelu(self.linear1(fea))
-        out = self.linear2(fea)
         return out
+
+
+class FixupResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000):
+        super(FixupResNet, self).__init__()
+        self.num_layers = sum(layers)
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bias1 = nn.Parameter(torch.zeros(1))
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.bias2 = nn.Parameter(torch.zeros(1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, FixupBasicBlock):
+                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
+                nn.init.constant_(m.conv2.weight, 0)
+                if m.downsample is not None:
+                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+            elif isinstance(m, FixupBottleneck):
+                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
+                nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
+                nn.init.constant_(m.conv3.weight, 0)
+                if m.downsample is not None:
+                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+            elif isinstance(m, nn.Linear):
+                nn.init.constant_(m.weight, 0)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu(x + self.bias1)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x + self.bias2)
+
+        return x
+
+
+def fixup_resnet18(**kwargs):
+    """Constructs a Fixup-ResNet-18 model.
+    """
+    model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2], **kwargs)
+    return model
+
+
+def fixup_resnet34(**kwargs):
+    """Constructs a Fixup-ResNet-34 model.
+    """
+    model = FixupResNet(FixupBasicBlock, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def fixup_resnet50(**kwargs):
+    """Constructs a Fixup-ResNet-50 model.
+    """
+    model = FixupResNet(FixupBottleneck, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def fixup_resnet101(**kwargs):
+    """Constructs a Fixup-ResNet-101 model.
+    """
+    model = FixupResNet(FixupBottleneck, [3, 4, 23, 3], **kwargs)
+    return model
+
+
+def fixup_resnet152(**kwargs):
+    """Constructs a Fixup-ResNet-152 model.
+ """ + model = FixupResNet(FixupBottleneck, [3, 8, 36, 3], **kwargs) + return model + + +__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] \ No newline at end of file diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index c33af559..bf634f6b 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -5,13 +5,8 @@ import torch.nn.functional as F import torch.nn.utils.spectral_norm as SpectralNorm from math import sqrt -def scale_conv_weights_fixup(conv, residual_block_count, m=2): - k = conv.kernel_size[0] - n = conv.out_channels - scaling_factor = residual_block_count ** (-1.0 / (2 * m - 2)) - sigma = sqrt(2 / (k * k * n)) * scaling_factor - conv.weight.data = conv.weight.data * sigma - return conv +def pixel_norm(x, epsilon=1e-8): + return x * torch.rsqrt(torch.mean(torch.pow(x, 2), dim=1, keepdims=True) + epsilon) def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): @@ -39,89 +34,6 @@ def make_layer(block, n_layers): layers.append(block()) return nn.Sequential(*layers) -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - -class FixupBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(FixupBasicBlock, self).__init__() - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.bias1a = nn.Parameter(torch.zeros(1)) - self.conv1 = conv3x3(inplanes, planes, stride) - self.bias1b = nn.Parameter(torch.zeros(1)) - self.relu = nn.ReLU(inplace=True) - self.bias2a = nn.Parameter(torch.zeros(1)) - self.conv2 = conv3x3(planes, planes) - self.scale = nn.Parameter(torch.ones(1)) - self.bias2b = nn.Parameter(torch.zeros(1)) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x + self.bias1a) - out = self.relu(out + self.bias1b) - - out = self.conv2(out + self.bias2a) - out = out * self.scale + self.bias2b - - if self.downsample is not None: - identity = self.downsample(x + self.bias1a) - - out += identity - out = self.relu(out) - - return out - -class FixupBottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(FixupBottleneck, self).__init__() - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.bias1a = nn.Parameter(torch.zeros(1)) - self.conv1 = conv1x1(inplanes, planes) - self.bias1b = nn.Parameter(torch.zeros(1)) - self.bias2a = nn.Parameter(torch.zeros(1)) - self.conv2 = conv3x3(planes, planes, stride) - self.bias2b = nn.Parameter(torch.zeros(1)) - self.bias3a = nn.Parameter(torch.zeros(1)) - self.conv3 = conv1x1(planes, planes * self.expansion) - self.scale = nn.Parameter(torch.ones(1)) - self.bias3b = nn.Parameter(torch.zeros(1)) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x + self.bias1a) - out = self.relu(out + self.bias1b) - - out = self.conv2(out + self.bias2a) - out = self.relu(out + self.bias2b) - - out = self.conv3(out + self.bias3a) - out = out * self.scale + self.bias3b - - if self.downsample is not 
-            identity = self.downsample(x + self.bias1a)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
 class ResidualBlock(nn.Module):
     '''Residual block with BN
     ---Conv-BN-ReLU-Conv-+-
@@ -157,11 +69,7 @@ class ResidualBlockSpectralNorm(nn.Module):
         self.conv1 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
         self.conv2 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
 
-        # Initialize first.
         initialize_weights([self.conv1, self.conv2], 1)
-        # Then perform fixup scaling
-        self.conv1 = scale_conv_weights_fixup(self.conv1, total_residual_blocks)
-        self.conv2 = scale_conv_weights_fixup(self.conv2, total_residual_blocks)
 
     def forward(self, x):
         identity = x
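
A minimal usage sketch (not part of the patch) of the Fixup constructors this change adds to DiscriminatorResnet_arch.py. The import path and the single-logit num_classes=1 head are assumptions based on the repo layout, not something the commit itself shows:

    import torch
    # Assumed import path, mirroring the repo's "models.archs.arch_util" convention.
    import models.archs.DiscriminatorResnet_arch as resnet_arch

    # Hypothetical smoke test: build the bottleneck variant as a one-logit critic.
    disc = resnet_arch.fixup_resnet50(num_classes=1)
    fake_batch = torch.randn(4, 3, 128, 128)  # NCHW input; AdaptiveAvgPool2d absorbs the spatial size
    logits = disc(fake_batch)                 # -> tensor of shape (4, 1)

The Fixup initialization in FixupResNet (zero-init on the last conv of each block, depth-scaled normal init on the others) is what lets this network train without the spectral normalization the removed DiscriminatorResnet relied on.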