From a5188bb7ca2c081770019469890a8c9914105c7b Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Wed, 29 Apr 2020 15:17:43 -0600
Subject: [PATCH] Remove fixup code from arch_util

The Fixup ResNet code is moving into its own arch file, DiscriminatorResnet_arch.py.
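
For reference, a minimal usage sketch of the relocated constructors. The import
path follows the repo's models.archs convention; num_classes=1 is an illustrative
choice, not something this patch sets:

    import torch
    from models.archs.DiscriminatorResnet_arch import fixup_resnet18

    # 18-layer Fixup variant; a single output logit for discriminator-style use.
    disc = fixup_resnet18(num_classes=1)
    scores = disc(torch.randn(4, 3, 224, 224))  # -> tensor of shape (4, 1)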
---
 .idea/mmsr.iml                                |   2 +
 .../models/archs/DiscriminatorResnet_arch.py  | 242 +++++++++++++-----
 codes/models/archs/arch_util.py               |  96 +------
 3 files changed, 180 insertions(+), 160 deletions(-)

diff --git a/.idea/mmsr.iml b/.idea/mmsr.iml
index 643c0574..75a1d172 100644
--- a/.idea/mmsr.iml
+++ b/.idea/mmsr.iml
@@ -2,8 +2,10 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$">
+      <sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
       <excludeFolder url="file://$MODULE_DIR$/datasets" />
       <excludeFolder url="file://$MODULE_DIR$/experiments" />
+      <excludeFolder url="file://$MODULE_DIR$/results" />
       <excludeFolder url="file://$MODULE_DIR$/tb_logger" />
     </content>
     <orderEntry type="jdk" jdkName="Python 3.7 (python37-torch)" jdkType="Python SDK" />
diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/DiscriminatorResnet_arch.py
index e30a3ab9..b1feea64 100644
--- a/codes/models/archs/DiscriminatorResnet_arch.py
+++ b/codes/models/archs/DiscriminatorResnet_arch.py
@@ -1,85 +1,195 @@
 import torch
 import torch.nn as nn
-import torchvision
-import models.archs.arch_util as arch_util
-import functools
-import torch.nn.functional as F
-import torch.nn.utils.spectral_norm as SpectralNorm
+import numpy as np
 
-# Class that halfs the image size (x4 complexity reduction) and doubles the filter size. Substantial resnet
-# processing is also performed.
-class ResnetDownsampleLayer(nn.Module):
-    def __init__(self, starting_channels: int, number_filters: int, filter_multiplier: int, residual_blocks_input: int, residual_blocks_skip_image: int, total_residual_blocks: int):
-        super(ResnetDownsampleLayer, self).__init__()
 
-        self.skip_image_reducer = SpectralNorm(nn.Conv2d(starting_channels, number_filters, 3, stride=1, padding=1, bias=True))
-        self.skip_image_res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters, total_residual_blocks=total_residual_blocks), residual_blocks_skip_image)
+__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
 
-        self.input_reducer = SpectralNorm(nn.Conv2d(number_filters, number_filters*filter_multiplier, 3, stride=2, padding=1, bias=True))
-        self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters*filter_multiplier, total_residual_blocks=total_residual_blocks), residual_blocks_input)
 
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class FixupBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBasicBlock, self).__init__()
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
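+        # Fixup (Zhang et al., 2019) drops normalization layers: each residual
+        # branch instead carries scalar biases before/after its convs and a scalar
+        # scale on the final conv output, initialized to 0 and 1 respectively.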
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bias1b = nn.Parameter(torch.zeros(1))
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        arch_util.initialize_weights([self.input_reducer, self.skip_image_reducer], 1)
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.downsample = downsample
+        self.stride = stride
 
-    def forward(self, x, skip_image):
-        # Process the skip image first.
-        skip = self.lrelu(self.skip_image_reducer(skip_image))
-        skip = self.skip_image_res_trunk(skip)
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x + self.bias1a)
+        out = self.lrelu(out + self.bias1b)
+
+        out = self.conv2(out + self.bias2a)
+        out = out * self.scale + self.bias2b
+
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.lrelu(out)
 
-        # Concat the processed skip image onto the input and perform processing.
-        out = (x + skip) / 2
-        out = self.lrelu(self.input_reducer(out))
-        out = self.res_trunk(out)
         return out
 
-class DiscriminatorResnet(nn.Module):
-    # Discriminator that downsamples 5 times with resnet blocks at each layer. On each downsample, the filter size is
-    # increased by a factor of 2. Feeds the output of the convs into a dense for prediction at the logits. Scales the
-    # final dense based on the input image size. Intended for use with input images which are multiples of 32.
-    #
-    # This discriminator also includes provisions to pass an image at various downsample steps in directly. When this
-    # is done with a generator, it will allow much shorter gradient paths between the generator and discriminator. When
-    # no downsampled images are passed into the forward() pass, they will be automatically generated from the source
-    # image using interpolation.
-    #
-    # Uses spectral normalization rather than batch normalization.
-    def __init__(self, in_nc: int, nf: int, input_img_size: int, trunk_resblocks: int, skip_resblocks: int):
-        super(DiscriminatorResnet, self).__init__()
-        self.dimensionalize = nn.Conv2d(in_nc, nf, kernel_size=3, stride=1, padding=1, bias=True)
+class FixupBottleneck(nn.Module):
+    expansion = 4
 
-        # Trunk resblocks are the important things to get right, so use those. 5=number of downsample layers.
-        total_resblocks = trunk_resblocks * 5
-        self.downsample1 = ResnetDownsampleLayer(in_nc, nf, 2, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsample2 = ResnetDownsampleLayer(in_nc, nf*2, 2, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsample3 = ResnetDownsampleLayer(in_nc, nf*4, 2, trunk_resblocks, skip_resblocks, total_resblocks)
-        # At the bottom layers, we cap the filter multiplier. We want this particular network to focus as much on the
-        # macro-details at higher image dimensionality as it does to the feature details.
-        self.downsample4 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsample5 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks)
-        self.downsamplers = [self.downsample1, self.downsample2, self.downsample3, self.downsample4, self.downsample5]
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBottleneck, self).__init__()
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bias1b = nn.Parameter(torch.zeros(1))
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.bias3a = nn.Parameter(torch.zeros(1))
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias3b = nn.Parameter(torch.zeros(1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.downsample = downsample
+        self.stride = stride
 
-        downsampled_image_size = input_img_size / 32
-        self.linear1 = nn.Linear(int(nf * 8 * downsampled_image_size * downsampled_image_size), 100)
-        self.linear2 = nn.Linear(100, 1)
+    def forward(self, x):
+        identity = x
 
-        # activation function
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        out = self.conv1(x + self.bias1a)
+        out = self.lrelu(out + self.bias1b)
 
-        arch_util.initialize_weights([self.dimensionalize, self.linear1, self.linear2], 1)
+        out = self.conv2(out + self.bias2a)
+        out = self.lrelu(out + self.bias2b)
 
-    def forward(self, x, skip_images=None):
-        if skip_images is None:
-            # Sythesize them from x.
-            skip_images = []
-            for i in range(len(self.downsamplers)):
-                m = 2 ** i
-                skip_images.append(F.interpolate(x, scale_factor=1 / m, mode='bilinear', align_corners=False))
+        out = self.conv3(out + self.bias3a)
+        out = out * self.scale + self.bias3b
 
-        fea = self.dimensionalize(x)
-        for skip, d in zip(skip_images, self.downsamplers):
-            fea = d(fea, skip)
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.lrelu(out)
 
-        fea = fea.view(fea.size(0), -1)
-        fea = self.lrelu(self.linear1(fea))
-        out = self.linear2(fea)
         return out
+
+
+class FixupResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000):
+        super(FixupResNet, self).__init__()
+        self.num_layers = sum(layers)
+        self.inplanes = 64
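+        # Standard ResNet stem, minus batch norm: a single learned scalar bias
+        # (self.bias1) stands in for the stem's normalization under Fixup.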
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bias1 = nn.Parameter(torch.zeros(1))
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.bias2 = nn.Parameter(torch.zeros(1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
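+        # Fixup initialization: scale the first conv of each residual branch by
+        # num_layers ** -0.5 (or ** -0.25 per conv in bottlenecks), zero-init the
+        # last conv of each branch and the classifier, so every residual branch
+        # starts out close to an identity mapping.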
+        for m in self.modules():
+            if isinstance(m, FixupBasicBlock):
+                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
+                nn.init.constant_(m.conv2.weight, 0)
+                if m.downsample is not None:
+                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+            elif isinstance(m, FixupBottleneck):
+                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
+                nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
+                nn.init.constant_(m.conv3.weight, 0)
+                if m.downsample is not None:
+                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+            elif isinstance(m, nn.Linear):
+                nn.init.constant_(m.weight, 0)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
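+            # Under Fixup the projection shortcut is a bare 1x1 conv with no norm
+            # layer; it receives plain He initialization in __init__ above.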
+            downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu(x + self.bias1)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x + self.bias2)
+
+        return x
+
+
+def fixup_resnet18(**kwargs):
+    """Constructs a Fixup-ResNet-18 model.2
+    """
+    model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2], **kwargs)
+    return model
+
+
+def fixup_resnet34(**kwargs):
+    """Constructs a Fixup-ResNet-34 model.
+    """
+    model = FixupResNet(FixupBasicBlock, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def fixup_resnet50(**kwargs):
+    """Constructs a Fixup-ResNet-50 model.
+    """
+    model = FixupResNet(FixupBottleneck, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def fixup_resnet101(**kwargs):
+    """Constructs a Fixup-ResNet-101 model.
+    """
+    model = FixupResNet(FixupBottleneck, [3, 4, 23, 3], **kwargs)
+    return model
+
+
+def fixup_resnet152(**kwargs):
+    """Constructs a Fixup-ResNet-152 model.
+    """
+    model = FixupResNet(FixupBottleneck, [3, 8, 36, 3], **kwargs)
+    return model
+
+
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index c33af559..bf634f6b 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -5,13 +5,8 @@ import torch.nn.functional as F
 import torch.nn.utils.spectral_norm as SpectralNorm
 from math import sqrt
 
-def scale_conv_weights_fixup(conv, residual_block_count, m=2):
-    k = conv.kernel_size[0]
-    n = conv.out_channels
-    scaling_factor = residual_block_count ** (-1.0 / (2 * m - 2))
-    sigma = sqrt(2 / (k * k * n)) * scaling_factor
-    conv.weight.data = conv.weight.data * sigma
-    return conv
+def pixel_norm(x, epsilon=1e-8):
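+    # Normalize each pixel's feature vector to unit average magnitude across the
+    # channel dim (pixelwise feature-vector normalization, as in Progressive GAN).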
+    return x * torch.rsqrt(torch.mean(torch.pow(x, 2), dim=1, keepdim=True) + epsilon)
 
 def initialize_weights(net_l, scale=1):
     if not isinstance(net_l, list):
@@ -39,89 +34,6 @@ def make_layer(block, n_layers):
         layers.append(block())
     return nn.Sequential(*layers)
 
-def conv3x3(in_planes, out_planes, stride=1):
-    """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=False)
-
-def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-class FixupBasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(FixupBasicBlock, self).__init__()
-        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
-        self.bias1a = nn.Parameter(torch.zeros(1))
-        self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bias1b = nn.Parameter(torch.zeros(1))
-        self.relu = nn.ReLU(inplace=True)
-        self.bias2a = nn.Parameter(torch.zeros(1))
-        self.conv2 = conv3x3(planes, planes)
-        self.scale = nn.Parameter(torch.ones(1))
-        self.bias2b = nn.Parameter(torch.zeros(1))
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x + self.bias1a)
-        out = self.relu(out + self.bias1b)
-
-        out = self.conv2(out + self.bias2a)
-        out = out * self.scale + self.bias2b
-
-        if self.downsample is not None:
-            identity = self.downsample(x + self.bias1a)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
-class FixupBottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(FixupBottleneck, self).__init__()
-        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
-        self.bias1a = nn.Parameter(torch.zeros(1))
-        self.conv1 = conv1x1(inplanes, planes)
-        self.bias1b = nn.Parameter(torch.zeros(1))
-        self.bias2a = nn.Parameter(torch.zeros(1))
-        self.conv2 = conv3x3(planes, planes, stride)
-        self.bias2b = nn.Parameter(torch.zeros(1))
-        self.bias3a = nn.Parameter(torch.zeros(1))
-        self.conv3 = conv1x1(planes, planes * self.expansion)
-        self.scale = nn.Parameter(torch.ones(1))
-        self.bias3b = nn.Parameter(torch.zeros(1))
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x + self.bias1a)
-        out = self.relu(out + self.bias1b)
-
-        out = self.conv2(out + self.bias2a)
-        out = self.relu(out + self.bias2b)
-
-        out = self.conv3(out + self.bias3a)
-        out = out * self.scale + self.bias3b
-
-        if self.downsample is not None:
-            identity = self.downsample(x + self.bias1a)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
 class ResidualBlock(nn.Module):
     '''Residual block with BN
     ---Conv-BN-ReLU-Conv-+-
@@ -157,11 +69,7 @@ class ResidualBlockSpectralNorm(nn.Module):
         self.conv1 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
         self.conv2 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
 
-        # Initialize first.
         initialize_weights([self.conv1, self.conv2], 1)
-        # Then perform fixup scaling
-        self.conv1 = scale_conv_weights_fixup(self.conv1, total_residual_blocks)
-        self.conv2 = scale_conv_weights_fixup(self.conv2, total_residual_blocks)
 
     def forward(self, x):
         identity = x