375 lines
15 KiB
Python
375 lines
15 KiB
Python
|
from collections import OrderedDict
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import torch.nn.functional as F
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
class BlurLayer(nn.Module):
    """Depthwise blur filter built from a separable 1-D kernel.

    The 1-D ``kernel`` is expanded to a 2-D kernel via an outer product and is
    applied independently to every input channel (``groups=channels``).

    :param kernel: 1-D filter taps; ``None`` uses ``[1, 2, 1]``.
    :param normalize: divide the 2-D kernel by its sum.
    :param flip: flip the 2-D kernel along both spatial axes.
    :param stride: convolution stride.
    """

    def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        if kernel is None:
            kernel = [1, 2, 1]
        kernel = torch.tensor(kernel, dtype=torch.float32)
        # Separable filter: outer product of the 1-D taps -> (k, k).
        kernel = kernel[:, None] * kernel[None, :]
        # Add (out_channels, in_channels / groups) axes -> (1, 1, k, k).
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # Torch tensors do not support negative-step slicing
            # (``kernel[..., ::-1]`` raises ValueError), so use torch.flip.
            kernel = torch.flip(kernel, dims=[2, 3])
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        """Blur every channel of ``x`` (dim 1 is treated as channels)."""
        channels = x.size(1)
        # One kernel copy per channel for a depthwise convolution; cast to the
        # input dtype so non-float32 inputs do not fail inside conv2d.
        kernel = self.kernel.to(dtype=x.dtype).expand(channels, -1, -1, -1)
        return F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=int((self.kernel.size(2) - 1) / 2),
            groups=channels,
        )
|
||
|
|
||
|
|
||
|
class Upscale2d(nn.Module):
    """Nearest-neighbour upsampling by an integer factor, with an optional gain.

    :param factor: integer upsampling factor (>= 1).
    :param gain: scalar multiplier applied before upsampling.
    """

    @staticmethod
    def upscale2d(x, factor=2, gain=1):
        """Multiply 4-D ``x`` by ``gain`` and repeat each pixel ``factor`` x ``factor`` times."""
        assert x.dim() == 4
        if gain != 1:
            x = x * gain
        if factor == 1:
            return x
        n, c, h, w = x.shape
        # Insert singleton axes after H and W, broadcast them to `factor`,
        # then fold them back into the spatial axes.
        tiled = x.view(n, c, h, 1, w, 1).expand(n, c, h, factor, w, factor)
        return tiled.reshape(n, c, h * factor, w * factor)

    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return self.upscale2d(x, factor=self.factor, gain=self.gain)
|
||
|
|
||
|
|
||
|
class Downscale2d(nn.Module):
    """Box-filter downsampling by an integer factor, with an optional gain.

    For ``factor == 2`` and float32 input, a strided ``BlurLayer`` whose 2x2
    kernel entries are ``gain / 4`` is used. Every other case multiplies by
    ``gain`` and then falls back to ``F.avg_pool2d`` (or returns early when
    ``factor == 1``).

    :param factor: integer downsampling factor (>= 1).
    :param gain: scalar multiplier applied to the result.
    """

    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.factor = factor
        self.gain = gain
        if factor != 2:
            self.blur = None
        else:
            # Each 1-D tap is sqrt(gain) / 2, so the 2-D outer product is
            # gain / 4 everywhere: a 2x2 average scaled by `gain`.
            taps = [np.sqrt(gain) / factor] * factor
            self.blur = BlurLayer(kernel=taps, normalize=False, stride=factor)

    def forward(self, x):
        assert x.dim() == 4
        # Fast path: the strided blur does the 2x2 average and the gain at once.
        if self.blur is not None and x.dtype == torch.float32:
            return self.blur(x)

        scaled = x * self.gain if self.gain != 1 else x
        if self.factor == 1:
            return scaled
        return F.avg_pool2d(scaled, self.factor)
|
||
|
|
||
|
|
||
|
class EqualizedConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier.

    Weights are stored with std ``init_std`` and multiplied by ``w_mul`` at
    run time; the bias is multiplied by ``b_mul``.

    :param input_channels: number of input channels.
    :param output_channels: number of output channels.
    :param kernel_size: square kernel size; padding ``kernel_size // 2`` is used.
    :param stride: stride of the plain (non-fused) convolution. The fused
        upscale/downscale paths always use their own stride of 2.
    :param gain: He-init gain.
    :param use_wscale: apply the He constant at run time (equalized learning rate).
    :param lrmul: learning-rate multiplier for weight and bias.
    :param bias: whether to learn an additive bias.
    :param intermediate: optional module applied after the convolution, before the bias.
    :param upscale: upsample 2x before the convolution.
    :param downscale: downsample 2x after the convolution.
    """

    def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** 0.5, use_wscale=False,
                 lrmul=1, bias=True, intermediate=None, upscale=False, downscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        if downscale:
            self.downscale = Downscale2d()
        else:
            self.downscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init
        self.kernel_size = kernel_size
        # Previously accepted but silently ignored; now used by the plain conv.
        self.stride = stride
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(
            torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # Fused upscale + conv from StyleGAN, used when the upscaled output's
            # smaller side is >= 128. Not numerically identical to the
            # non-fused path below.
            w = self.weight * self.w_mul
            # conv_transpose2d expects weights laid out as (in, out, kh, kw).
            w = w.permute(1, 0, 2, 3)
            # Sum of the four 1-pixel shifts of the zero-padded kernel -> (k+1)x(k+1).
            w = F.pad(w, [1, 1, 1, 1])
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        intermediate = self.intermediate
        if self.downscale is not None and min(x.shape[2:]) >= 128:
            # Fused conv + 2x downscale: mean of the four shifted kernels, stride 2.
            w = self.weight * self.w_mul
            w = F.pad(w, [1, 1, 1, 1])
            w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25
            x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.downscale is not None:
            # Non-fused downscale runs after the conv, in the `intermediate` slot.
            assert intermediate is None
            intermediate = self.downscale

        if not have_convolution and intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias,
                            stride=self.stride, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None,
                         stride=self.stride, padding=self.kernel_size // 2)

        if intermediate is not None:
            x = intermediate(x)

        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x
|
||
|
|
||
|
|
||
|
class EqualizedLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier.

    With ``use_wscale`` the stored weights are drawn with std ``1 / lrmul`` and
    the He constant is applied at run time through ``w_mul``. Without it, the
    He constant is baked into the initial weights and ``w_mul == lrmul``.

    :param input_size: number of input features.
    :param output_size: number of output features.
    :param gain: He-init gain.
    :param use_wscale: enable run-time weight scaling.
    :param lrmul: learning-rate multiplier for weight and bias.
    :param bias: whether to learn an additive bias.
    """

    def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size ** (-0.5)  # He init
        init_std, self.w_mul = ((1.0 / lrmul, he_std * lrmul) if use_wscale
                                else (he_std / lrmul, lrmul))
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        self.bias = torch.nn.Parameter(torch.zeros(output_size)) if bias else None
        if bias:
            self.b_mul = lrmul

    def forward(self, x):
        scaled_bias = None if self.bias is None else self.bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, scaled_bias)
|
||
|
|
||
|
|
||
|
class View(nn.Module):
    """Reshape every sample to ``shape`` while keeping the leading batch axis.

    :param shape: per-sample target shape (``-1`` allowed, as in ``Tensor.view``).
    """

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, *self.shape)
|
||
|
|
||
|
|
||
|
class StddevLayer(nn.Module):
    """Minibatch standard-deviation layer that appends per-group stddev maps.

    The batch is split using an effective group size ``G = min(group_size, batch)``.
    Sample ``i`` shares statistics with the samples ``i % M, i % M + M, ...``,
    where ``M = batch // G``. For each such group, the stddev over its members
    is averaged over channels-per-feature and spatial positions. This yields
    ``num_new_features`` scalars, which are broadcast to ``H x W`` maps and
    concatenated after the input channels.

    Expects 4-D input with the batch divisible by ``G`` and the channel count
    divisible by ``num_new_features``.
    """

    def __init__(self, group_size=4, num_new_features=1):
        super().__init__()
        self.group_size = group_size
        self.num_new_features = num_new_features

    def forward(self, x):
        batch, channels, height, width = x.shape
        group = min(self.group_size, batch)
        feats = self.num_new_features
        # [G, M, F, C/F, H, W]: entry (g, m) is sample g * M + m.
        grouped = x.reshape([group, -1, feats, channels // feats, height, width])
        centered = grouped - grouped.mean(0, keepdim=True)
        variance = centered.pow(2).mean(0, keepdim=True)
        stddev = (variance + 1e-8) ** 0.5
        # Average over C/F, H, W -> [1, M, F], then restore spatial axes -> [1, M, F, 1, 1].
        stats = stddev.mean([3, 4, 5])[:, :, :, None, None]
        stat_maps = stats.expand(group, -1, -1, height, width).reshape(batch, feats, height, width)
        return torch.cat([x, stat_maps], dim=1)
|
||
|
|
||
|
|
||
|
class DiscriminatorBlock(nn.Sequential):
    """One discriminator stage: conv -> act -> blur -> conv with 2x downscale -> act.

    :param in_channels: channels of the incoming feature map (kept by ``conv0``).
    :param out_channels: channels produced by ``conv1_down``.
    :param gain: He-init gain for both convolutions.
    :param use_wscale: enable equalized learning rate.
    :param activation_layer: activation module, reused after both convolutions.
    :param blur_kernel: 1-D taps for the ``BlurLayer`` (``None`` = its default).
    """

    def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel):
        # Submodules are built in the same order as the Sequential entries so the
        # RNG draws for the weights are unchanged.
        conv0 = EqualizedConv2d(in_channels, in_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        blur = BlurLayer(kernel=blur_kernel)
        # out channels nf(res-1)
        conv1_down = EqualizedConv2d(in_channels, out_channels, kernel_size=3,
                                     gain=gain, use_wscale=use_wscale, downscale=True)
        stages = [
            ('conv0', conv0),
            ('act0', activation_layer),
            ('blur', blur),
            ('conv1_down', conv1_down),
            ('act1', activation_layer),
        ]
        super().__init__(OrderedDict(stages))
|
||
|
|
||
|
|
||
|
|
||
|
class DiscriminatorTop(nn.Sequential):
    def __init__(self,
                 mbstd_group_size,
                 mbstd_num_features,
                 in_channels,
                 intermediate_channels,
                 gain, use_wscale,
                 activation_layer,
                 resolution=4,
                 in_channels2=None,
                 output_features=1,
                 last_gain=1):
        """Final discriminator stage.

        Pipeline: [minibatch stddev] -> conv -> act -> flatten -> dense0 -> act -> dense1.

        :param mbstd_group_size: group size of the minibatch-stddev layer; the
            layer is only added when this is > 1.
        :param mbstd_num_features: feature maps appended by the minibatch-stddev
            layer (only counted when that layer is present).
        :param in_channels: channels of the incoming feature map.
        :param intermediate_channels: width of the hidden dense layer.
        :param gain: He-init gain for the conv and ``dense0``.
        :param use_wscale: enable equalized learning rate.
        :param activation_layer: activation module, reused after conv and ``dense0``.
        :param resolution: spatial size of the incoming feature map; ``dense0``
            expects ``in_channels2 * resolution * resolution`` flattened inputs.
        :param in_channels2: output channels of the conv; defaults to ``in_channels``.
        :param output_features: number of output scores.
        :param last_gain: He-init gain for ``dense1``.
        """
        layers = []
        conv_in_channels = in_channels
        if mbstd_group_size > 1:
            layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))
            # The stddev layer concatenates extra maps onto the input. Only then
            # does the conv see extra channels; previously these were counted
            # even when the layer was disabled, causing a channel mismatch.
            conv_in_channels += mbstd_num_features

        if in_channels2 is None:
            in_channels2 = in_channels

        layers.append(('conv', EqualizedConv2d(conv_in_channels, in_channels2, kernel_size=3,
                                               gain=gain, use_wscale=use_wscale)))
        layers.append(('act0', activation_layer))
        layers.append(('view', View(-1)))
        layers.append(('dense0', EqualizedLinear(in_channels2 * resolution * resolution, intermediate_channels,
                                                 gain=gain, use_wscale=use_wscale)))
        layers.append(('act1', activation_layer))
        layers.append(('dense1', EqualizedLinear(intermediate_channels, output_features,
                                                 gain=last_gain, use_wscale=use_wscale)))

        super().__init__(OrderedDict(layers))
|
||
|
|
||
|
|
||
|
class StyleGanDiscriminator(nn.Module):
    def __init__(self, resolution, num_channels=3, fmap_base=8192, fmap_decay=1.0, fmap_max=512,
                 nonlinearity='lrelu', use_wscale=True, mbstd_group_size=4, mbstd_num_features=1,
                 blur_filter=None, structure='fixed', **kwargs):
        """
        Discriminator used in the StyleGAN paper.

        :param resolution: Input resolution; must be a power of two >= 4.
        :param num_channels: Number of input color channels.
        :param fmap_base: Overall multiplier for the number of feature maps.
        :param fmap_decay: log2 feature map reduction when doubling the resolution.
        :param fmap_max: Maximum number of feature maps in any layer.
        :param nonlinearity: Activation function: 'relu', 'lrelu'.
        :param use_wscale: Enable equalized learning rate?
        :param mbstd_group_size: Group size for the minibatch standard deviation layer;
            values <= 1 disable it.
        :param mbstd_num_features: Number of features for the minibatch standard deviation layer.
        :param blur_filter: 1-D taps for the blur in every block. None uses
            BlurLayer's default [1, 2, 1] kernel; it does not disable blurring.
        :param structure: 'fixed' = no progressive growing; 'linear' = progressive
            growing, where forward() uses `depth` and `alpha` to fade in the newest block.
        :param kwargs: Ignore unrecognized keyword args.
        """
        super(StyleGanDiscriminator, self).__init__()

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.mbstd_num_features = mbstd_num_features
        self.mbstd_group_size = mbstd_group_size
        self.structure = structure

        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2 ** resolution_log2 and resolution >= 4
        self.depth = resolution_log2 - 1

        # Activations must be nn.Module instances because they are placed in
        # nn.Sequential containers; a plain function such as torch.relu raises
        # TypeError there.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]

        # blocks[i] consumes the output of from_rgb[i]; index 0 is the highest
        # resolution. Blocks and fromRGB layers are created interleaved, in this
        # order, to keep the weight-initialisation RNG sequence unchanged.
        blocks = []
        from_rgb = []
        for res in range(resolution_log2, 2, -1):
            blocks.append(DiscriminatorBlock(nf(res - 1), nf(res - 2),
                                             gain=gain, use_wscale=use_wscale, activation_layer=act,
                                             blur_kernel=blur_filter))
            from_rgb.append(EqualizedConv2d(num_channels, nf(res - 1), kernel_size=1,
                                            gain=gain, use_wscale=use_wscale))
        self.blocks = nn.ModuleList(blocks)

        # Final (4x4) block and its fromRGB layer (from_rgb[-1]).
        self.final_block = DiscriminatorTop(self.mbstd_group_size, self.mbstd_num_features,
                                            in_channels=nf(2), intermediate_channels=nf(2),
                                            gain=gain, use_wscale=use_wscale, activation_layer=act)
        from_rgb.append(EqualizedConv2d(num_channels, nf(2), kernel_size=1,
                                        gain=gain, use_wscale=use_wscale))
        self.from_rgb = nn.ModuleList(from_rgb)

        # 2x average pooling used for the fade-in shortcut in 'linear' mode.
        self.temporaryDownsampler = nn.AvgPool2d(2)

    def forward(self, images_in, depth=0, alpha=1.):
        """
        :param images_in: Images [mini_batch, channel, height, width].
        :param depth: 'linear' structure only. 0 feeds the input straight to the
            final 4x4 block, and each increment adds one more block in front
            (twice the input resolution). Must be < self.depth.
        :param alpha: 'linear' structure only. Fade-in weight of the newest block
            versus the downsampled shortcut.
        :return: scores produced by the final block.
        :raises KeyError: if ``self.structure`` is neither 'fixed' nor 'linear'.
        """
        if self.structure == 'fixed':
            x = self.from_rgb[0](images_in)
            for block in self.blocks:
                x = block(x)
            scores_out = self.final_block(x)

        elif self.structure == 'linear':
            assert depth < self.depth, "Requested output depth cannot be produced"
            if depth > 0:
                # Blend the newest block's output with a 2x-downsampled shortcut
                # fed through the next-lower fromRGB layer.
                residual = self.from_rgb[self.depth - depth](self.temporaryDownsampler(images_in))
                straight = self.blocks[self.depth - depth - 1](self.from_rgb[self.depth - depth - 1](images_in))
                x = (alpha * straight) + ((1 - alpha) * residual)

                for block in self.blocks[(self.depth - depth):]:
                    x = block(x)
            else:
                x = self.from_rgb[-1](images_in)

            scores_out = self.final_block(x)

        else:
            raise KeyError("Unknown structure: {}".format(self.structure))

        return scores_out
|