Apply fixes to resgen
parent 446322754a
commit 3c2e5a0250
@@ -13,18 +13,18 @@ import data.util as data_util  # noqa: E402

 def main():
     mode = 'single'  # single (one input folder) | pair (extract corresponding GT and LR pairs)
-    split_img = False
+    split_img = True
     opt = {}
     opt['n_thread'] = 20
     opt['compression_level'] = 3  # 3 is the default value in cv2
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
     if mode == 'single':
-        opt['input_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\images'
-        opt['save_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\tiled'
-        opt['crop_sz'] = 64  # the size of each sub-image
-        opt['step'] = 48  # step of the sliding crop window
-        opt['thres_sz'] = 20  # size threshold
+        opt['input_folder'] = 'F:\\4k6k\\datasets\\vrp\\images_sized'
+        opt['save_folder'] = 'F:\\4k6k\\datasets\\vrp\\images_tiled'
+        opt['crop_sz'] = 320  # the size of each sub-image
+        opt['step'] = 280  # step of the sliding crop window
+        opt['thres_sz'] = 200  # size threshold
         extract_single(opt, split_img)
     elif mode == 'pair':
         GT_folder = '../../datasets/div2k/DIV2K_train_HR'
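
For context on how the retuned tiling options interact, below is a minimal sketch of the sliding-window arithmetic, assuming the worker follows the usual extract_subimages pattern (the 1950 px width and the variable names are illustrative, not part of this commit):

import numpy as np

crop_sz, step, thres_sz = 320, 280, 200   # the values configured above
w = 1950                                  # illustrative image width
w_space = np.arange(0, w - crop_sz + 1, step)
# If the leftover right border is wider than thres_sz, add one extra crop
# flush against the edge so that strip is not silently dropped.
if w - (w_space[-1] + crop_sz) > thres_sz:
    w_space = np.append(w_space, w - crop_sz)
print(w_space)  # [   0  280  560  840 1120 1400 1630]
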
@@ -120,8 +120,8 @@ def worker(path, opt, split_mode=False, left_img=True):
         raise ValueError('Wrong image shape - {}'.format(n_channels))

     # Uncomment to filter any image that doesnt meet a threshold size.
-    #if w < 3000:
-    #    return
+    if w < 3000:
+        return

     left = 0
     right = w
@@ -152,10 +152,10 @@ def worker(path, opt, split_mode=False, left_img=True):
             crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
             crop_img = np.ascontiguousarray(crop_img)
             # If this fails, change it and the imwrite below to the write extension.
-            assert img_name.contains(".png")
+            assert ".jpg" in img_name
             cv2.imwrite(
                 osp.join(opt['save_folder'],
-                         img_name.replace('.png', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
+                         img_name.replace('.jpg', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
                 [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
     return 'Processing {:s} ...'.format(img_name)

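
The extension change above also drives the output naming; a quick sketch of the pattern the new code emits (the file name, offset and index values are illustrative):

img_name, left, index = 'frame_0001.jpg', 0, 12
out_name = img_name.replace('.jpg', '_l{:05d}_s{:03d}.png'.format(left, index))
print(out_name)  # frame_0001_l00000_s012.png
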
@@ -13,10 +13,15 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)

 def conv5x5(in_planes, out_planes, stride=1):
-    """3x3 convolution with padding"""
+    """5x5 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
                      padding=2, bias=False)

+def conv7x7(in_planes, out_planes, stride=1):
+    """7x7 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
+                     padding=3, bias=False)
+
 def conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -62,7 +67,7 @@ class FixupResNet(nn.Module):
     def __init__(self, block, layers, upscale_applications=2, num_filters=64, inject_noise=False):
         super(FixupResNet, self).__init__()
         self.inject_noise = inject_noise
-        self.num_layers = sum(layers) + layers[-1]  # The last layer is applied twice to achieve 4x upsampling.
+        self.num_layers = sum(layers) + layers[-1] * (upscale_applications - 1)  # The last layer is applied repeatedly to achieve high level SR.
         self.inplanes = num_filters
         self.upscale_applications = upscale_applications
         # Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated
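
A worked example of the new layer count, with illustrative values; note that for upscale_applications=2 the new formula reduces to the old one:

layers = [2, 2, 2]                 # illustrative per-stage resblock counts
upscale_applications = 3           # e.g. targeting 8x SR
old_count = sum(layers) + layers[-1]                                # 8, hard-coded for the 4x case
new_count = sum(layers) + layers[-1] * (upscale_applications - 1)   # 10, scales with the target
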
@ -1,17 +1,12 @@
|
||||||
# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
|
# Source: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import init
|
from torch.nn import init
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Parameter as P
|
|
||||||
|
|
||||||
import models.archs.biggan_layers as layers
|
import models.archs.biggan_layers as layers
|
||||||
from models.archs.biggan_sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
|
|
||||||
|
|
||||||
# BigGAN-deep: uses a different resblock and pattern
|
# BigGAN-deep: uses a different resblock and pattern
|
||||||
|
|
||||||
|
@@ -45,11 +40,11 @@ class GBlock(nn.Module):
         # upsample layers
         self.upsample = upsample

-    def forward(self, x, y):
+    def forward(self, x):
         # Project down to channel ratio
-        h = self.conv1(self.activation(self.bn1(x, y)))
+        h = self.conv1(self.activation(self.bn1(x)))
         # Apply next BN-ReLU
-        h = self.activation(self.bn2(h, y))
+        h = self.activation(self.bn2(h))
         # Drop channels in x if necessary
         if self.in_channels != self.out_channels:
             x = x[:, :self.out_channels]
@@ -59,61 +54,38 @@ class GBlock(nn.Module):
             x = self.upsample(x)
         # 3x3 convs
         h = self.conv2(h)
-        h = self.conv3(self.activation(self.bn3(h, y)))
+        h = self.conv3(self.activation(self.bn3(h)))
         # Final 1x1 conv
-        h = self.conv4(self.activation(self.bn4(h, y)))
+        h = self.conv4(self.activation(self.bn4(h)))
         return h + x


-def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
+def G_arch(ch=64, attention='64', base_width=64):
     arch = {}
-    arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
-                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
-                 'upsample': [True] * 6,
-                 'resolution': [8, 16, 32, 64, 128, 256],
-                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
-                               for i in range(3, 9)}}
-    arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
-                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
-                 'upsample': [True] * 5,
-                 'resolution': [8, 16, 32, 64, 128],
+    arch[128] = {'in_channels': [ch * item for item in [2, 2, 1, 1]],
+                 'out_channels': [ch * item for item in [2, 1, 1, 1]],
+                 'upsample': [False, True, False, False],
+                 'resolution': [base_width, base_width, base_width*2, base_width*2],
                  'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                                for i in range(3, 8)}}
-    arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
-                'out_channels': [ch * item for item in [16, 8, 4, 2]],
-                'upsample': [True] * 4,
-                'resolution': [8, 16, 32, 64],
-                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
-                              for i in range(3, 7)}}
-    arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
-                'out_channels': [ch * item for item in [4, 4, 4]],
-                'upsample': [True] * 3,
-                'resolution': [8, 16, 32],
-                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
-                              for i in range(3, 6)}}

     return arch


 class Generator(nn.Module):
-    def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
-                 G_kernel_size=3, G_attn='64', n_classes=1000,
-                 num_G_SVs=1, num_G_SV_itrs=1,
-                 G_shared=True, shared_dim=0, hier=False,
+    def __init__(self, G_ch=64, G_depth=2, bottom_width=4, resolution=128,
+                 G_kernel_size=3, G_attn='64',
+                 num_G_SVs=1, num_G_SV_itrs=1, hier=False,
                  cross_replica=False, mybn=False,
                  G_activation=nn.ReLU(inplace=False),
-                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
-                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
-                 G_init='ortho', skip_init=False, no_optim=False,
-                 G_param='SN', norm_style='bn',
-                 **kwargs):
+                 BN_eps=1e-5, SN_eps=1e-12,
+                 G_init='ortho', skip_init=False,
+                 G_param='SN', norm_style='bn'):
         super(Generator, self).__init__()
-        # Channel width mulitplier
+        # Channel width multiplier
         self.ch = G_ch
         # Number of resblocks per stage
         self.G_depth = G_depth
-        # Dimensionality of the latent space
-        self.dim_z = dim_z
         # The initial spatial dimensions
         self.bottom_width = bottom_width
         # Resolution of the output
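
To make the rewritten table concrete, these are the values it yields for the defaults ch=64, base_width=64 (derived directly from the new dict above):

ch, base_width = 64, 64
in_channels  = [ch * m for m in [2, 2, 1, 1]]   # [128, 128, 64, 64]
out_channels = [ch * m for m in [2, 1, 1, 1]]   # [128, 64, 64, 64]
upsample     = [False, True, False, False]      # a single 2x upsample mid-network
resolution   = [base_width, base_width, base_width * 2, base_width * 2]  # [64, 64, 128, 128]
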
@@ -122,12 +94,6 @@ class Generator(nn.Module):
         self.kernel_size = G_kernel_size
         # Attention?
         self.attention = G_attn
-        # number of classes, for use in categorical conditional generation
-        self.n_classes = n_classes
-        # Use shared embeddings?
-        self.G_shared = G_shared
-        # Dimensionality of the shared embedding? Unused if not using G_shared
-        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
         # Hierarchical latent space?
         self.hier = hier
         # Cross replica batchnorm?
@@ -146,8 +112,6 @@ class Generator(nn.Module):
         self.BN_eps = BN_eps
         # Epsilon for Spectral Norm?
         self.SN_eps = SN_eps
-        # fp16?
-        self.fp16 = G_fp16
         # Architecture dict
         self.arch = G_arch(self.ch, self.attention)[resolution]

@@ -157,34 +121,23 @@ class Generator(nn.Module):
                                                 kernel_size=3, padding=1,
                                                 num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                 eps=self.SN_eps)
-            self.which_linear = functools.partial(layers.SNLinear,
-                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
-                                                  eps=self.SN_eps)
         else:
             self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
-            self.which_linear = nn.Linear

-        # We use a non-spectral-normed embedding here regardless;
-        # For some reason applying SN to G's embedding seems to randomly cripple G
-        self.which_embedding = nn.Embedding
-        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
-                     else self.which_embedding)
-        self.which_bn = functools.partial(layers.ccbn,
-                                          which_linear=bn_linear,
+        self.which_bn = functools.partial(layers.bn,
                                           cross_replica=self.cross_replica,
                                           mybn=self.mybn,
-                                          input_size=(self.shared_dim + self.dim_z if self.G_shared
-                                                      else self.n_classes),
                                           norm_style=self.norm_style,
                                           eps=self.BN_eps)

         # Prepare model
-        # If not using shared embeddings, self.shared is just a passthrough
-        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
-                       else layers.identity())
-        # First linear layer
-        self.linear = self.which_linear(self.dim_z + self.shared_dim,
-                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))
+        # First conv layer to project into feature-space
+        self.initial_conv = nn.Sequential(self.which_conv(3, self.arch['in_channels'][0]),
+                                          layers.bn(self.arch['in_channels'][0],
+                                                    cross_replica=self.cross_replica,
+                                                    mybn=self.mybn),
+                                          self.activation
+                                          )

         # self.blocks is a doubly-nested list of modules, the outer loop intended
         # to be over blocks at a given resolution (resblocks and/or self-attention)
@@ -222,26 +175,6 @@ class Generator(nn.Module):
         if not skip_init:
             self.init_weights()

-        # Set up optimizer
-        # If this is an EMA copy, no need for an optim, so just return now
-        if no_optim:
-            return
-        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
-        if G_mixed_precision:
-            print('Using fp16 adam in G...')
-            import utils
-            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
-                                      betas=(self.B1, self.B2), weight_decay=0,
-                                      eps=self.adam_eps)
-        else:
-            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
-                                    betas=(self.B1, self.B2), weight_decay=0,
-                                    eps=self.adam_eps)
-
-        # LR scheduling, left here for forward compatibility
-        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
-        # self.j = 0
-
     # Initialize
     def init_weights(self):
         self.param_count = 0
@@ -260,25 +193,17 @@ class Generator(nn.Module):
                 self.param_count += sum([p.data.nelement() for p in module.parameters()])
         print('Param count for G''s initialized parameters: %d' % self.param_count)

-    # Note on this forward function: we pass in a y vector which has
-    # already been passed through G.shared to enable easy class-wise
-    # interpolation later. If we passed in the one-hot and then ran it through
-    # G.shared in this forward function, it would be harder to handle.
-    # NOTE: The z vs y dichotomy here is for compatibility with not-y
-    def forward(self, z, y):
-        # If hierarchical, concatenate zs and ys
-        if self.hier:
-            z = torch.cat([y, z], 1)
-            y = z
-        # First linear layer
-        h = self.linear(z)
-        # Reshape
-        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
+    def forward(self, z):
+        # First conv layer to convert into correct filter-space.
+        h = self.initial_conv(z)
         # Loop over blocks
         for index, blocklist in enumerate(self.blocks):
             # Second inner loop in case block has multiple layers
             for block in blocklist:
-                h = block(h, y)
+                h = block(h)

         # Apply batchnorm-relu-conv-tanh at output
-        return torch.tanh(self.output_layer(h))
+        return (torch.tanh(self.output_layer(h)), )
+
+def biggan_medium(num_filters):
+    return Generator(num_filters)
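
A minimal sketch of the new calling convention, assuming the generator is built through the biggan_medium helper added above; the batch size and 64x64 input are illustrative:

import torch

netG = biggan_medium(num_filters=64)   # wraps Generator(64)
lr = torch.randn(1, 3, 64, 64)         # image input; initial_conv expects 3 channels
sr, = netG(lr)                         # forward now takes only z and returns a 1-tuple
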
@@ -1,16 +1,11 @@
 ''' Layers
 This file contains various layers for the BigGAN models.
 '''
-import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn import init
-import torch.optim as optim
 import torch.nn.functional as F
 from torch.nn import Parameter as P

-from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
-

 # Projection of x onto y
 def proj(x, y):
@ -336,7 +331,7 @@ class ccbn(nn.Module):
|
||||||
# Normal, non-class-conditional BN
|
# Normal, non-class-conditional BN
|
||||||
class bn(nn.Module):
|
class bn(nn.Module):
|
||||||
def __init__(self, output_size, eps=1e-5, momentum=0.1,
|
def __init__(self, output_size, eps=1e-5, momentum=0.1,
|
||||||
cross_replica=False, mybn=False):
|
cross_replica=False, mybn=False, norm_style=None):
|
||||||
super(bn, self).__init__()
|
super(bn, self).__init__()
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
# Prepare gain and bias layers
|
# Prepare gain and bias layers
|
||||||
|
|
|
@@ -16,8 +16,6 @@ import torch.nn.functional as F
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

-from .comm import SyncMaster
-
 __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


@@ -348,4 +346,144 @@ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
                              .format(input.dim()))
         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
+
+
+# From ccomm.py
+# -*- coding: utf-8 -*-
+# File   : comm.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, 'Previous result has\'t been fetched.'
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
+    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
+    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+    and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine to message passed
+    back to each slave devices.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def __getstate__(self):
+        return {'master_callback': self._master_callback}
+
+    def __setstate__(self, state):
+        self.__init__(state['master_callback'])
+
+    def register_slave(self, identifier):
+        """
+        Register an slave device.
+
+        Args:
+            identifier: an identifier, usually is the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages were first collected from each devices (including the master device), and then
+        an callback will be invoked to compute the message to be sent back to each devices
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master want to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belongs to the master.'
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
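
The inlined SyncMaster machinery can be exercised on its own; a sketch with a single fake slave thread and a callback that doubles every message (doubling_callback and the message values are illustrative, not part of the commit):

import threading

def doubling_callback(intermediates):
    # intermediates is [(identifier, msg), ...] with the master's entry first
    return [(i, msg * 2) for i, msg in intermediates]

master = SyncMaster(doubling_callback)
pipe = master.register_slave(1)

t = threading.Thread(target=lambda: print('slave received', pipe.run_slave(21)))
t.start()
print('master received', master.run_master(10))  # master received 20; slave received 42
t.join()
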
@@ -35,7 +35,7 @@ def define_G(opt, net_key='network_G'):
                                    upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'],
                                    inject_noise=opt_net['inject_noise'])
     elif which_model == "BigGan":
-        netG = biggan_arch.biggan_medium(filters=opt_net['nf'])
+        netG = biggan_arch.biggan_medium(num_filters=opt_net['nf'])

     # image corruption
     elif which_model == 'HighToLowResNet':
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/finetune_vix_resgenv2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_biggan.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)