diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/DiscriminatorResnet_arch.py
index e0438ed3..70f80929 100644
--- a/codes/models/archs/DiscriminatorResnet_arch.py
+++ b/codes/models/archs/DiscriminatorResnet_arch.py
@@ -94,21 +94,20 @@ class FixupBottleneck(nn.Module):
 
 class FixupResNet(nn.Module):
 
-    def __init__(self, block, layers, num_classes=1000):
+    def __init__(self, block, layers, num_filters=64, num_classes=1000):
         super(FixupResNet, self).__init__()
         self.num_layers = sum(layers)
-        self.inplanes = 64
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+        self.inplanes = num_filters
+        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3,
                                bias=False)
         self.bias1 = nn.Parameter(torch.zeros(1))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.layer1 = self._make_layer(block, num_filters, layers[0], stride=2)
+        self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
         self.bias2 = nn.Parameter(torch.zeros(1))
-        self.fc1 = nn.Linear(512 * 2 * 2, 100)
+        self.fc1 = nn.Linear(num_filters * 8 * 2 * 2, 100)
         self.fc2 = nn.Linear(100, num_classes)
 
         for m in self.modules():
@@ -123,9 +122,10 @@ class FixupResNet(nn.Module):
                 nn.init.constant_(m.conv3.weight, 0)
                 if m.downsample is not None:
                     nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+            '''
             elif isinstance(m, nn.Linear):
                 nn.init.constant_(m.weight, 0)
-                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.bias, 0)'''
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
@@ -143,7 +143,6 @@ class FixupResNet(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
         x = self.lrelu(x + self.bias1)
-        x = self.maxpool(x)
 
         x = self.layer1(x)
         x = self.layer2(x)
diff --git a/codes/models/archs/FlatProcessorNetNew_arch.py b/codes/models/archs/FlatProcessorNetNew_arch.py
new file mode 100644
index 00000000..bc164fcf
--- /dev/null
+++ b/codes/models/archs/FlatProcessorNetNew_arch.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class FixupBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBasicBlock, self).__init__()
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes, affine=True)
+        self.bias1b = nn.Parameter(torch.zeros(1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes, affine=True)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x + self.bias1a)
+        out = self.lrelu(out + self.bias1b)
+
+        out = self.conv2(out + self.bias2a)
+        out = out * self.scale + self.bias2b
+
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.lrelu(out)
+
+        return out
+
+
+class FixupResNet(nn.Module):
+
+    def __init__(self, block, num_filters, layers, num_classes=1000):
+        super(FixupResNet, self).__init__()
+        self.num_layers = sum(layers)
+        self.bias1 = nn.Parameter(torch.zeros(1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        # 4 input channels, including the noise.
+        self.conv1 = nn.Conv2d(4, num_filters, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+
+        self.inplanes = num_filters
+        self.down_layer1 = self._make_layer(block, num_filters, layers[0])
+        self.down_layer2 = self._make_layer(block, num_filters, layers[1], stride=2)
+        self.down_layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
+        self.down_layer4 = self._make_layer(block, num_filters * 16, layers[3], stride=2)
+
+        self.inplanes = num_filters * 4
+        self.up_layer1 = self._make_layer(block, num_filters * 4, layers[4], stride=1)
+        self.inplanes = num_filters
+        self.up_layer2 = self._make_layer(block, num_filters, layers[5], stride=1)
+
+        self.defilter = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False)
+
+        for m in self.modules():
+            if isinstance(m, FixupBasicBlock):
+                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
+                nn.init.constant_(m.conv2.weight, 0)
+                if m.downsample is not None:
+                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+            elif isinstance(m, nn.Linear):
+                nn.init.constant_(m.weight, 0)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        skip = x
+
+        # Noise has the same shape as the input with only one channel.
+        rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
+        x = torch.cat([x, rand_feature], dim=1)
+
+        x = self.conv1(x)
+        x = self.lrelu(x + self.bias1)
+
+        x = self.down_layer1(x)
+        x = self.down_layer2(x)
+        x = self.down_layer3(x)
+        x = self.down_layer4(x)
+
+        x = self.pixel_shuffle(x)
+        x = self.up_layer1(x)
+        x = self.pixel_shuffle(x)
+        x = self.up_layer2(x)
+
+        x = self.defilter(x)
+
+        base = F.interpolate(skip, scale_factor=.25, mode='bilinear', align_corners=False)
+        return x + base
+
+
+def fixup_resnet34(num_filters, **kwargs):
+    """Constructs a Fixup-ResNet-34 model.
+    """
+    model = FixupResNet(FixupBasicBlock, num_filters, [3, 4, 6, 3, 2, 2], **kwargs)
+    return model
\ No newline at end of file
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 546f4659..995670bd 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -3,6 +3,7 @@ import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
 import models.archs.DiscriminatorResnetBN_arch as DiscriminatorResnetBN_arch
+import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.EDVR_arch as EDVR_arch
 import models.archs.HighToLowResNet as HighToLowResNet
@@ -30,9 +31,10 @@ def define_G(opt):
         netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                                nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
     elif which_model == 'FlatProcessorNet':
-        netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+        '''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                                       nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
-                                                      assembler_blocks=opt_net['assembler_blocks'])
+                                                      assembler_blocks=opt_net['assembler_blocks'])'''
+        netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])
     # video restoration
     elif which_model == 'EDVR':
         netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
@@ -56,7 +58,7 @@ def define_D(opt):
     if which_model == 'discriminator_vgg_128':
         netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
     elif which_model == 'discriminator_resnet':
-        netD = DiscriminatorResnetBN_arch.resnet32(num_filters=opt_net['nf'], num_classes=1)
+        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1)
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml
index f2f0edd6..220d3375 100644
--- a/codes/options/train/train_GAN_blacked_corrupt.yml
+++ b/codes/options/train/train_GAN_blacked_corrupt.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 0  # per GPU
-    batch_size: 16
+    n_workers: 8  # per GPU
+    batch_size: 32
     target_size: 64
     use_flip: false
     use_rot: false
@@ -34,11 +34,14 @@ network_G:
   which_model_G: FlatProcessorNet
   in_nc: 3
   out_nc: 3
-  nf: 48
-  ra_blocks: 4
-  assembler_blocks: 3
+  nf: 32
+  ra_blocks: 6
+  assembler_blocks: 4
 
 network_D:
+  #which_model_D: discriminator_vgg_128
+  #in_nc: 3
+  #nf: 64
   which_model_D: discriminator_resnet
   in_nc: 3
   nf: 64
@@ -56,7 +59,7 @@ train:
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 1e-4
+  lr_D: !!float 2e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -71,11 +74,11 @@ train:
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: ragan  # gan | ragan
+  gan_type: gan  # gan | ragan
   gan_weight: !!float 1e-1
 
   D_update_ratio: 2
-  D_init_iters: 1200
+  D_init_iters: 0
 
   manual_seed: 10
   val_freq: !!float 5e2