Apply fixes to resgen

James Betker 2020-05-24 07:43:23 -06:00
parent 446322754a
commit 3c2e5a0250
7 changed files with 194 additions and 131 deletions

View File

@@ -13,18 +13,18 @@ import data.util as data_util # noqa: E402
def main():
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
split_img = False
split_img = True
opt = {}
opt['n_thread'] = 20
opt['compression_level'] = 3 # 3 is the default value in cv2
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If reading raw images during training, use 0 for faster IO speed.
if mode == 'single':
opt['input_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\images'
opt['save_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\tiled'
opt['crop_sz'] = 64 # the size of each sub-image
opt['step'] = 48 # step of the sliding crop window
opt['thres_sz'] = 20 # size threshold
opt['input_folder'] = 'F:\\4k6k\\datasets\\vrp\\images_sized'
opt['save_folder'] = 'F:\\4k6k\\datasets\\vrp\\images_tiled'
opt['crop_sz'] = 320 # the size of each sub-image
opt['step'] = 280 # step of the sliding crop window
opt['thres_sz'] = 200 # size threshold
extract_single(opt, split_img)
elif mode == 'pair':
GT_folder = '../../datasets/div2k/DIV2K_train_HR'
@@ -120,8 +120,8 @@ def worker(path, opt, split_mode=False, left_img=True):
raise ValueError('Wrong image shape - {}'.format(n_channels))
# Uncomment to filter any image that doesn't meet a threshold size.
#if w < 3000:
# return
if w < 3000:
return
left = 0
right = w
@@ -152,10 +152,10 @@ def worker(path, opt, split_mode=False, left_img=True):
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
crop_img = np.ascontiguousarray(crop_img)
# If this fails, change it and the imwrite below to the right extension.
assert img_name.contains(".png")
assert ".jpg" in img_name
cv2.imwrite(
osp.join(opt['save_folder'],
img_name.replace('.png', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
img_name.replace('.jpg', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
return 'Processing {:s} ...'.format(img_name)
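
For context, the crop_sz / step / thres_sz options changed above drive a standard sliding-window tiler: crops of crop_sz pixels are taken every step pixels along each axis, and one extra crop flush with the border is appended only when the uncovered remainder exceeds thres_sz. A minimal sketch of that index computation, assuming this is the usual pattern behind these options rather than the repository's exact worker:

def tile_starts(length, crop_sz, step, thres_sz):
    # Start indices of crop_sz-wide tiles along one image axis.
    starts = list(range(0, length - crop_sz + 1, step))
    # If the uncovered border is wider than thres_sz, add a final tile
    # aligned with the end of the axis.
    if length - (starts[-1] + crop_sz) > thres_sz:
        starts.append(length - crop_sz)
    return starts

# With the new settings: 320-px tiles every 280 px on a 1920-px axis.
print(tile_starts(1920, crop_sz=320, step=280, thres_sz=200))  # [0, 280, 560, 840, 1120, 1400]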

View File

@@ -13,10 +13,15 @@ def conv3x3(in_planes, out_planes, stride=1):
padding=1, bias=False)
def conv5x5(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
"""5x5 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
padding=2, bias=False)
def conv7x7(in_planes, out_planes, stride=1):
"""7x7 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
padding=3, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -62,7 +67,7 @@ class FixupResNet(nn.Module):
def __init__(self, block, layers, upscale_applications=2, num_filters=64, inject_noise=False):
super(FixupResNet, self).__init__()
self.inject_noise = inject_noise
self.num_layers = sum(layers) + layers[-1] # The last layer is applied twice to achieve 4x upsampling.
self.num_layers = sum(layers) + layers[-1] * (upscale_applications - 1) # The last layer is applied repeatedly to achieve high level SR.
self.inplanes = num_filters
self.upscale_applications = upscale_applications
# Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated
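
Fixup-style initialization typically scales its weights by the total block count, so the corrected num_layers expression above changes with upscale_applications. A quick worked check, using a hypothetical layers list (the real per-stage counts come from the network options):

layers = [2, 2, 2, 2]  # hypothetical per-stage block counts
for upscale_applications in (2, 3):
    num_layers = sum(layers) + layers[-1] * (upscale_applications - 1)
    print(upscale_applications, num_layers)  # 2 -> 10, 3 -> 12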

View File

@@ -1,17 +1,12 @@
# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
import numpy as np
import math
import functools
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
import models.archs.biggan_layers as layers
from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
# BigGAN-deep: uses a different resblock and pattern
@@ -45,11 +40,11 @@ class GBlock(nn.Module):
# upsample layers
self.upsample = upsample
def forward(self, x, y):
def forward(self, x):
# Project down to channel ratio
h = self.conv1(self.activation(self.bn1(x, y)))
h = self.conv1(self.activation(self.bn1(x)))
# Apply next BN-ReLU
h = self.activation(self.bn2(h, y))
h = self.activation(self.bn2(h))
# Drop channels in x if necessary
if self.in_channels != self.out_channels:
x = x[:, :self.out_channels]
@@ -59,61 +54,38 @@ class GBlock(nn.Module):
x = self.upsample(x)
# 3x3 convs
h = self.conv2(h)
h = self.conv3(self.activation(self.bn3(h, y)))
h = self.conv3(self.activation(self.bn3(h)))
# Final 1x1 conv
h = self.conv4(self.activation(self.bn4(h, y)))
h = self.conv4(self.activation(self.bn4(h)))
return h + x
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
def G_arch(ch=64, attention='64', base_width=64):
arch = {}
arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
'upsample': [True] * 6,
'resolution': [8, 16, 32, 64, 128, 256],
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
for i in range(3, 9)}}
arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
'upsample': [True] * 5,
'resolution': [8, 16, 32, 64, 128],
arch[128] = {'in_channels': [ch * item for item in [2, 2, 1, 1]],
'out_channels': [ch * item for item in [2, 1, 1, 1]],
'upsample': [False, True, False, False],
'resolution': [base_width, base_width, base_width*2, base_width*2],
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
for i in range(3, 8)}}
arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
'out_channels': [ch * item for item in [16, 8, 4, 2]],
'upsample': [True] * 4,
'resolution': [8, 16, 32, 64],
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
for i in range(3, 7)}}
arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
'out_channels': [ch * item for item in [4, 4, 4]],
'upsample': [True] * 3,
'resolution': [8, 16, 32],
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
for i in range(3, 6)}}
return arch
class Generator(nn.Module):
def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
G_kernel_size=3, G_attn='64', n_classes=1000,
num_G_SVs=1, num_G_SV_itrs=1,
G_shared=True, shared_dim=0, hier=False,
def __init__(self, G_ch=64, G_depth=2, bottom_width=4, resolution=128,
G_kernel_size=3, G_attn='64',
num_G_SVs=1, num_G_SV_itrs=1, hier=False,
cross_replica=False, mybn=False,
G_activation=nn.ReLU(inplace=False),
G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
G_init='ortho', skip_init=False, no_optim=False,
G_param='SN', norm_style='bn',
**kwargs):
BN_eps=1e-5, SN_eps=1e-12,
G_init='ortho', skip_init=False,
G_param='SN', norm_style='bn'):
super(Generator, self).__init__()
# Channel width mulitplier
# Channel width multiplier
self.ch = G_ch
# Number of resblocks per stage
self.G_depth = G_depth
# Dimensionality of the latent space
self.dim_z = dim_z
# The initial spatial dimensions
self.bottom_width = bottom_width
# Resolution of the output
@@ -122,12 +94,6 @@ class Generator(nn.Module):
self.kernel_size = G_kernel_size
# Attention?
self.attention = G_attn
# number of classes, for use in categorical conditional generation
self.n_classes = n_classes
# Use shared embeddings?
self.G_shared = G_shared
# Dimensionality of the shared embedding? Unused if not using G_shared
self.shared_dim = shared_dim if shared_dim > 0 else dim_z
# Hierarchical latent space?
self.hier = hier
# Cross replica batchnorm?
@@ -146,8 +112,6 @@ class Generator(nn.Module):
self.BN_eps = BN_eps
# Epsilon for Spectral Norm?
self.SN_eps = SN_eps
# fp16?
self.fp16 = G_fp16
# Architecture dict
self.arch = G_arch(self.ch, self.attention)[resolution]
@@ -157,34 +121,23 @@ class Generator(nn.Module):
kernel_size=3, padding=1,
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
eps=self.SN_eps)
self.which_linear = functools.partial(layers.SNLinear,
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
eps=self.SN_eps)
else:
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
self.which_linear = nn.Linear
# We use a non-spectral-normed embedding here regardless;
# For some reason applying SN to G's embedding seems to randomly cripple G
self.which_embedding = nn.Embedding
bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
else self.which_embedding)
self.which_bn = functools.partial(layers.ccbn,
which_linear=bn_linear,
self.which_bn = functools.partial(layers.bn,
cross_replica=self.cross_replica,
mybn=self.mybn,
input_size=(self.shared_dim + self.dim_z if self.G_shared
else self.n_classes),
norm_style=self.norm_style,
eps=self.BN_eps)
# Prepare model
# If not using shared embeddings, self.shared is just a passthrough
self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
else layers.identity())
# First linear layer
self.linear = self.which_linear(self.dim_z + self.shared_dim,
self.arch['in_channels'][0] * (self.bottom_width ** 2))
# First conv layer to project into feature-space
self.initial_conv = nn.Sequential(self.which_conv(3, self.arch['in_channels'][0]),
layers.bn(self.arch['in_channels'][0],
cross_replica=self.cross_replica,
mybn=self.mybn),
self.activation
)
# self.blocks is a doubly-nested list of modules, the outer loop intended
# to be over blocks at a given resolution (resblocks and/or self-attention)
@@ -222,26 +175,6 @@ class Generator(nn.Module):
if not skip_init:
self.init_weights()
# Set up optimizer
# If this is an EMA copy, no need for an optim, so just return now
if no_optim:
return
self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
if G_mixed_precision:
print('Using fp16 adam in G...')
import utils
self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
betas=(self.B1, self.B2), weight_decay=0,
eps=self.adam_eps)
else:
self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
betas=(self.B1, self.B2), weight_decay=0,
eps=self.adam_eps)
# LR scheduling, left here for forward compatibility
# self.lr_sched = {'itr' : 0}# if self.progressive else {}
# self.j = 0
# Initialize
def init_weights(self):
self.param_count = 0
@@ -260,25 +193,17 @@ class Generator(nn.Module):
self.param_count += sum([p.data.nelement() for p in module.parameters()])
print('Param count for G''s initialized parameters: %d' % self.param_count)
# Note on this forward function: we pass in a y vector which has
# already been passed through G.shared to enable easy class-wise
# interpolation later. If we passed in the one-hot and then ran it through
# G.shared in this forward function, it would be harder to handle.
# NOTE: The z vs y dichotomy here is for compatibility with not-y
def forward(self, z, y):
# If hierarchical, concatenate zs and ys
if self.hier:
z = torch.cat([y, z], 1)
y = z
# First linear layer
h = self.linear(z)
# Reshape
h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
def forward(self, z):
# First conv layer to convert into correct filter-space.
h = self.initial_conv(z)
# Loop over blocks
for index, blocklist in enumerate(self.blocks):
# Second inner loop in case block has multiple layers
for block in blocklist:
h = block(h, y)
h = block(h)
# Apply batchnorm-relu-conv-tanh at output
return torch.tanh(self.output_layer(h))
return (torch.tanh(self.output_layer(h)), )
def biggan_medium(num_filters):
return Generator(num_filters)
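
With the class-conditional plumbing stripped out, the Generator is now an image-to-image network: forward takes a single image tensor, pushes it through initial_conv and the GBlocks, and returns a one-element tuple. A hedged usage sketch; the import path, the 64-px input size, and the 2x output scale are assumptions inferred from the arch[128] entry above, which has exactly one upsample=True stage:

import torch
# Import path is an assumption; networks.py only refers to it as biggan_arch.
from models.archs import biggan_arch

netG = biggan_arch.biggan_medium(num_filters=64).eval()
lr = torch.randn(1, 3, 64, 64)   # an image batch replaces the old (z, y) latent/class inputs
with torch.no_grad():
    sr, = netG(lr)               # forward now returns a one-element tuple
print(sr.shape)                  # expected torch.Size([1, 3, 128, 128])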

View File

@@ -1,16 +1,11 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
# Projection of x onto y
def proj(x, y):
@@ -336,7 +331,7 @@ class ccbn(nn.Module):
# Normal, non-class-conditional BN
class bn(nn.Module):
def __init__(self, output_size, eps=1e-5, momentum=0.1,
cross_replica=False, mybn=False):
cross_replica=False, mybn=False, norm_style=None):
super(bn, self).__init__()
self.output_size = output_size
# Prepare gain and bias layers
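
The extra norm_style keyword on the plain bn class is what lets the generator's simplified which_bn partial (see the BigGANdeep diff above) keep passing norm_style without the class-conditional ccbn machinery. A minimal sketch of how that partial resolves, assuming the module is importable under the path the generator file uses:

import functools
import models.archs.biggan_layers as layers  # import path taken from the generator file above

which_bn = functools.partial(layers.bn,
                             cross_replica=False,
                             mybn=False,
                             norm_style='bn',
                             eps=1e-5)
bn1 = which_bn(128)  # unconditional BN over 128 channels; GBlock now calls it as self.bn1(x)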

View File

@@ -16,8 +16,6 @@ import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from .comm import SyncMaster
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
@@ -348,4 +346,144 @@ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
# From comm.py
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, data parallel triggers a callback on each module; every slave device should
call `register(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
and passed to the registered callback.
- After receiving the messages, the master device gathers the information and determines the message to pass
back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
Messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
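
The appended comm.py section is a thread-based request/response pipe: each replica registers with the SyncMaster ahead of the forward pass, pushes its message through run_slave, and the master aggregates everything inside run_master via the supplied callback. A toy sketch of that handshake, assuming SyncMaster from the code above is in scope, with plain threads standing in for GPU replicas and a summing callback standing in for the batch-norm statistics reduction (not how _SynchronizedBatchNorm wires it up internally):

import threading

def sum_callback(intermediates):
    # intermediates is a list of (identifier, message); identifier 0 is the master.
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(sum_callback)
pipes = {i: master.register_slave(i) for i in (1, 2)}  # registration happens before run_master

results = {}

def replica(i):
    results[i] = pipes[i].run_slave(i * 10)  # each slave contributes its local message

threads = [threading.Thread(target=replica, args=(i,)) for i in (1, 2)]
for t in threads:
    t.start()
results[0] = master.run_master(5)  # master contributes 5 and gets the reduced value back
for t in threads:
    t.join()
print(results)  # every replica sees 5 + 10 + 20 = 35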

View File

@@ -35,7 +35,7 @@ def define_G(opt, net_key='network_G'):
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'],
inject_noise=opt_net['inject_noise'])
elif which_model == "BigGan":
netG = biggan_arch.biggan_medium(filters=opt_net['nf'])
netG = biggan_arch.biggan_medium(num_filters=opt_net['nf'])
# image corruption
elif which_model == 'HighToLowResNet':
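
The one-line fix in define_G is purely about the keyword name: biggan_medium is declared as biggan_medium(num_filters), so the old filters= call raised a TypeError as soon as which_model == "BigGan" was selected. A quick illustration against a stub with the same signature:

def biggan_medium(num_filters):  # same signature as in the BigGAN generator file above
    return ('Generator stub', num_filters)

biggan_medium(num_filters=64)    # works
# biggan_medium(filters=64)      # TypeError: got an unexpected keyword argument 'filters'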

View File

@@ -30,7 +30,7 @@ def main():
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/finetune_vix_resgenv2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_biggan.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)