From 3c2e5a0250f7cc71ab354dd1a945166c94c91076 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 24 May 2020 07:43:23 -0600 Subject: [PATCH] Apply fixes to resgen --- codes/data_scripts/extract_subimages.py | 20 +-- codes/models/archs/ResGen_arch.py | 9 +- codes/models/archs/biggan_gen_arch.py | 141 +++++-------------- codes/models/archs/biggan_layers.py | 7 +- codes/models/archs/biggan_sync_batchnorm.py | 144 +++++++++++++++++++- codes/models/networks.py | 2 +- codes/train.py | 2 +- 7 files changed, 194 insertions(+), 131 deletions(-) diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py index dd0cf41f..a3dbd38f 100644 --- a/codes/data_scripts/extract_subimages.py +++ b/codes/data_scripts/extract_subimages.py @@ -13,18 +13,18 @@ import data.util as data_util # noqa: E402 def main(): mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs) - split_img = False + split_img = True opt = {} opt['n_thread'] = 20 opt['compression_level'] = 3 # 3 is the default value in cv2 # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. if mode == 'single': - opt['input_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\images' - opt['save_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\tiled' - opt['crop_sz'] = 64 # the size of each sub-image - opt['step'] = 48 # step of the sliding crop window - opt['thres_sz'] = 20 # size threshold + opt['input_folder'] = 'F:\\4k6k\\datasets\\vrp\\images_sized' + opt['save_folder'] = 'F:\\4k6k\\datasets\\vrp\\images_tiled' + opt['crop_sz'] = 320 # the size of each sub-image + opt['step'] = 280 # step of the sliding crop window + opt['thres_sz'] = 200 # size threshold extract_single(opt, split_img) elif mode == 'pair': GT_folder = '../../datasets/div2k/DIV2K_train_HR' @@ -120,8 +120,8 @@ def worker(path, opt, split_mode=False, left_img=True): raise ValueError('Wrong image shape - {}'.format(n_channels)) # Uncomment to filter any image that doesnt meet a threshold size. - #if w < 3000: - # return + if w < 3000: + return left = 0 right = w @@ -152,10 +152,10 @@ def worker(path, opt, split_mode=False, left_img=True): crop_img = img[x:x + crop_sz, y:y + crop_sz, :] crop_img = np.ascontiguousarray(crop_img) # If this fails, change it and the imwrite below to the write extension. - assert img_name.contains(".png") + assert ".jpg" in img_name cv2.imwrite( osp.join(opt['save_folder'], - img_name.replace('.png', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img, + img_name.replace('.jpg', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img, [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) return 'Processing {:s} ...'.format(img_name) diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py index ddaa3f36..0eefcc5e 100644 --- a/codes/models/archs/ResGen_arch.py +++ b/codes/models/archs/ResGen_arch.py @@ -13,10 +13,15 @@ def conv3x3(in_planes, out_planes, stride=1): padding=1, bias=False) def conv5x5(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" + """5x5 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, padding=2, bias=False) +def conv7x7(in_planes, out_planes, stride=1): + """7x7 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, + padding=3, bias=False) + def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) @@ -62,7 +67,7 @@ class FixupResNet(nn.Module): def __init__(self, block, layers, upscale_applications=2, num_filters=64, inject_noise=False): super(FixupResNet, self).__init__() self.inject_noise = inject_noise - self.num_layers = sum(layers) + layers[-1] # The last layer is applied twice to achieve 4x upsampling. + self.num_layers = sum(layers) + layers[-1] * (upscale_applications - 1) # The last layer is applied repeatedly to achieve high level SR. self.inplanes = num_filters self.upscale_applications = upscale_applications # Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated diff --git a/codes/models/archs/biggan_gen_arch.py b/codes/models/archs/biggan_gen_arch.py index 6264967a..9ba4f704 100644 --- a/codes/models/archs/biggan_gen_arch.py +++ b/codes/models/archs/biggan_gen_arch.py @@ -1,17 +1,12 @@ # Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py -import numpy as np -import math import functools import torch import torch.nn as nn from torch.nn import init -import torch.optim as optim import torch.nn.functional as F -from torch.nn import Parameter as P import models.archs.biggan_layers as layers -from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d # BigGAN-deep: uses a different resblock and pattern @@ -45,11 +40,11 @@ class GBlock(nn.Module): # upsample layers self.upsample = upsample - def forward(self, x, y): + def forward(self, x): # Project down to channel ratio - h = self.conv1(self.activation(self.bn1(x, y))) + h = self.conv1(self.activation(self.bn1(x))) # Apply next BN-ReLU - h = self.activation(self.bn2(h, y)) + h = self.activation(self.bn2(h)) # Drop channels in x if necessary if self.in_channels != self.out_channels: x = x[:, :self.out_channels] @@ -59,61 +54,38 @@ class GBlock(nn.Module): x = self.upsample(x) # 3x3 convs h = self.conv2(h) - h = self.conv3(self.activation(self.bn3(h, y))) + h = self.conv3(self.activation(self.bn3(h))) # Final 1x1 conv - h = self.conv4(self.activation(self.bn4(h, y))) + h = self.conv4(self.activation(self.bn4(h))) return h + x -def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): +def G_arch(ch=64, attention='64', base_width=64): arch = {} - arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]], - 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]], - 'upsample': [True] * 6, - 'resolution': [8, 16, 32, 64, 128, 256], - 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) - for i in range(3, 9)}} - arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]], - 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]], - 'upsample': [True] * 5, - 'resolution': [8, 16, 32, 64, 128], + arch[128] = {'in_channels': [ch * item for item in [2, 2, 1, 1]], + 'out_channels': [ch * item for item in [2, 1, 1, 1]], + 'upsample': [False, True, False, False], + 'resolution': [base_width, base_width, base_width*2, base_width*2], 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) for i in range(3, 8)}} - arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]], - 'out_channels': [ch * item for item in [16, 8, 4, 2]], - 'upsample': [True] * 4, - 'resolution': [8, 16, 32, 64], - 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) - for i in range(3, 7)}} - arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]], - 'out_channels': [ch * item for item in [4, 4, 4]], - 'upsample': [True] * 3, - 'resolution': [8, 16, 32], - 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) - for i in range(3, 6)}} return arch class Generator(nn.Module): - def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128, - G_kernel_size=3, G_attn='64', n_classes=1000, - num_G_SVs=1, num_G_SV_itrs=1, - G_shared=True, shared_dim=0, hier=False, + def __init__(self, G_ch=64, G_depth=2, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', + num_G_SVs=1, num_G_SV_itrs=1, hier=False, cross_replica=False, mybn=False, G_activation=nn.ReLU(inplace=False), - G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, - BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, - G_init='ortho', skip_init=False, no_optim=False, - G_param='SN', norm_style='bn', - **kwargs): + BN_eps=1e-5, SN_eps=1e-12, + G_init='ortho', skip_init=False, + G_param='SN', norm_style='bn'): super(Generator, self).__init__() - # Channel width mulitplier + # Channel width multiplier self.ch = G_ch # Number of resblocks per stage self.G_depth = G_depth - # Dimensionality of the latent space - self.dim_z = dim_z # The initial spatial dimensions self.bottom_width = bottom_width # Resolution of the output @@ -122,12 +94,6 @@ class Generator(nn.Module): self.kernel_size = G_kernel_size # Attention? self.attention = G_attn - # number of classes, for use in categorical conditional generation - self.n_classes = n_classes - # Use shared embeddings? - self.G_shared = G_shared - # Dimensionality of the shared embedding? Unused if not using G_shared - self.shared_dim = shared_dim if shared_dim > 0 else dim_z # Hierarchical latent space? self.hier = hier # Cross replica batchnorm? @@ -146,8 +112,6 @@ class Generator(nn.Module): self.BN_eps = BN_eps # Epsilon for Spectral Norm? self.SN_eps = SN_eps - # fp16? - self.fp16 = G_fp16 # Architecture dict self.arch = G_arch(self.ch, self.attention)[resolution] @@ -157,34 +121,23 @@ class Generator(nn.Module): kernel_size=3, padding=1, num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, eps=self.SN_eps) - self.which_linear = functools.partial(layers.SNLinear, - num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, - eps=self.SN_eps) else: self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) - self.which_linear = nn.Linear - # We use a non-spectral-normed embedding here regardless; - # For some reason applying SN to G's embedding seems to randomly cripple G - self.which_embedding = nn.Embedding - bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared - else self.which_embedding) - self.which_bn = functools.partial(layers.ccbn, - which_linear=bn_linear, + self.which_bn = functools.partial(layers.bn, cross_replica=self.cross_replica, mybn=self.mybn, - input_size=(self.shared_dim + self.dim_z if self.G_shared - else self.n_classes), norm_style=self.norm_style, eps=self.BN_eps) # Prepare model - # If not using shared embeddings, self.shared is just a passthrough - self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared - else layers.identity()) - # First linear layer - self.linear = self.which_linear(self.dim_z + self.shared_dim, - self.arch['in_channels'][0] * (self.bottom_width ** 2)) + # First conv layer to project into feature-space + self.initial_conv = nn.Sequential(self.which_conv(3, self.arch['in_channels'][0]), + layers.bn(self.arch['in_channels'][0], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation + ) # self.blocks is a doubly-nested list of modules, the outer loop intended # to be over blocks at a given resolution (resblocks and/or self-attention) @@ -222,26 +175,6 @@ class Generator(nn.Module): if not skip_init: self.init_weights() - # Set up optimizer - # If this is an EMA copy, no need for an optim, so just return now - if no_optim: - return - self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps - if G_mixed_precision: - print('Using fp16 adam in G...') - import utils - self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, - betas=(self.B1, self.B2), weight_decay=0, - eps=self.adam_eps) - else: - self.optim = optim.Adam(params=self.parameters(), lr=self.lr, - betas=(self.B1, self.B2), weight_decay=0, - eps=self.adam_eps) - - # LR scheduling, left here for forward compatibility - # self.lr_sched = {'itr' : 0}# if self.progressive else {} - # self.j = 0 - # Initialize def init_weights(self): self.param_count = 0 @@ -260,25 +193,17 @@ class Generator(nn.Module): self.param_count += sum([p.data.nelement() for p in module.parameters()]) print('Param count for G''s initialized parameters: %d' % self.param_count) - # Note on this forward function: we pass in a y vector which has - # already been passed through G.shared to enable easy class-wise - # interpolation later. If we passed in the one-hot and then ran it through - # G.shared in this forward function, it would be harder to handle. - # NOTE: The z vs y dichotomy here is for compatibility with not-y - def forward(self, z, y): - # If hierarchical, concatenate zs and ys - if self.hier: - z = torch.cat([y, z], 1) - y = z - # First linear layer - h = self.linear(z) - # Reshape - h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + def forward(self, z): + # First conv layer to convert into correct filter-space. + h = self.initial_conv(z) # Loop over blocks for index, blocklist in enumerate(self.blocks): # Second inner loop in case block has multiple layers for block in blocklist: - h = block(h, y) + h = block(h) # Apply batchnorm-relu-conv-tanh at output - return torch.tanh(self.output_layer(h)) \ No newline at end of file + return (torch.tanh(self.output_layer(h)), ) + +def biggan_medium(num_filters): + return Generator(num_filters) \ No newline at end of file diff --git a/codes/models/archs/biggan_layers.py b/codes/models/archs/biggan_layers.py index 8e1179e5..58e24fc4 100644 --- a/codes/models/archs/biggan_layers.py +++ b/codes/models/archs/biggan_layers.py @@ -1,16 +1,11 @@ ''' Layers This file contains various layers for the BigGAN models. ''' -import numpy as np import torch import torch.nn as nn -from torch.nn import init -import torch.optim as optim import torch.nn.functional as F from torch.nn import Parameter as P -from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d - # Projection of x onto y def proj(x, y): @@ -336,7 +331,7 @@ class ccbn(nn.Module): # Normal, non-class-conditional BN class bn(nn.Module): def __init__(self, output_size, eps=1e-5, momentum=0.1, - cross_replica=False, mybn=False): + cross_replica=False, mybn=False, norm_style=None): super(bn, self).__init__() self.output_size = output_size # Prepare gain and bias layers diff --git a/codes/models/archs/biggan_sync_batchnorm.py b/codes/models/archs/biggan_sync_batchnorm.py index a55a75a9..42e9c852 100644 --- a/codes/models/archs/biggan_sync_batchnorm.py +++ b/codes/models/archs/biggan_sync_batchnorm.py @@ -16,8 +16,6 @@ import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast -from .comm import SyncMaster - __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] @@ -348,4 +346,144 @@ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' .format(input.dim())) - super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) + + +# From ccomm.py +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index c66ac347..b5d83a42 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -35,7 +35,7 @@ def define_G(opt, net_key='network_G'): upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'], inject_noise=opt_net['inject_noise']) elif which_model == "BigGan": - netG = biggan_arch.biggan_medium(filters=opt_net['nf']) + netG = biggan_arch.biggan_medium(num_filters=opt_net['nf']) # image corruption elif which_model == 'HighToLowResNet': diff --git a/codes/train.py b/codes/train.py index 6a09b391..f372293b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/finetune_vix_resgenv2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_biggan.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)