from collections import OrderedDict

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class BlurLayer(nn.Module):
    def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        if kernel is None:
            kernel = [1, 2, 1]
        kernel = torch.tensor(kernel, dtype=torch.float32)
        # Outer product turns the 1D filter into a separable 2D kernel.
        kernel = kernel[:, None] * kernel[None, :]
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # torch.flip instead of [::-1] slicing: tensors do not support
            # negative-step slices.
            kernel = torch.flip(kernel, dims=[2, 3])
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # Expand the kernel over the channel dimension and apply it as a
        # depthwise convolution.
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=(self.kernel.size(2) - 1) // 2,
            groups=x.size(1)
        )
        return x


class Upscale2d(nn.Module):
    @staticmethod
    def upscale2d(x, factor=2, gain=1):
        assert x.dim() == 4
        if gain != 1:
            x = x * gain
        if factor != 1:
            # Nearest-neighbour upsampling: replicate each pixel `factor`
            # times along height and width.
            shape = x.shape
            x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
            x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
        return x

    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return self.upscale2d(x, factor=self.factor, gain=self.gain)


class Downscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.factor = factor
        self.gain = gain
        if factor == 2:
            # A constant 2x2 kernel applied with stride 2 is average pooling
            # fused with the gain.
            f = [np.sqrt(gain) / factor] * factor
            self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)
        else:
            self.blur = None

    def forward(self, x):
        assert x.dim() == 4
        # 2x2, float32 => downscale using the blur layer.
        if self.blur is not None and x.dtype == torch.float32:
            return self.blur(x)
        # Apply gain.
        if self.gain != 1:
            x = x * self.gain
        # No-op => early exit.
        if self.factor == 1:
            return x
        # Large factor => downscale using average pooling.
        return F.avg_pool2d(x, self.factor)
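# Shape sketch for the resampling helpers above (illustrative only, not part
# of the original module; the demo function name is made up here).
def _resampling_shape_demo():
    x = torch.randn(2, 8, 16, 16)
    assert BlurLayer()(x).shape == (2, 8, 16, 16)          # 3x3 blur keeps H, W
    assert Upscale2d(factor=2)(x).shape == (2, 8, 32, 32)  # 2x nearest-neighbour
    assert Downscale2d(factor=2)(x).shape == (2, 8, 8, 8)  # fused 2x2 average pool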
class EqualizedConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** 0.5, use_wscale=False,
                 lrmul=1, bias=True, intermediate=None, upscale=False, downscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        if downscale:
            self.downscale = Downscale2d()
        else:
            self.downscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(
            torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # This is the fused upscale + conv from StyleGAN; sadly it seems
            # incompatible with the non-fused path, and it really ought to be
            # cleaned up and moved into the conv itself.
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # Pad and sum the kernel into a (kernel_size + 1)^2 kernel for the
            # transposed convolution; applying a conv on w directly would
            # probably be more efficient.
            w = F.pad(w, [1, 1, 1, 1])
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        downscale = self.downscale
        intermediate = self.intermediate
        if downscale is not None and min(x.shape[2:]) >= 128:
            w = self.weight * self.w_mul
            w = F.pad(w, [1, 1, 1, 1])
            # In contrast to the fused upscale, this is a mean, i.e. fused
            # average pooling.
            w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25
            x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
            downscale = None
        elif downscale is not None:
            assert intermediate is None
            intermediate = downscale

        if not have_convolution and intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)

        if intermediate is not None:
            x = intermediate(x)

        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x


class EqualizedLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size ** (-0.5)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)


class View(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)


class StddevLayer(nn.Module):
    def __init__(self, group_size=4, num_new_features=1):
        super().__init__()
        self.group_size = group_size
        self.num_new_features = num_new_features

    def forward(self, x):
        b, c, h, w = x.shape
        # The batch size must be divisible by the (possibly clipped) group size.
        group_size = min(self.group_size, b)
        y = x.reshape([group_size, -1, self.num_new_features, c // self.num_new_features, h, w])
        y = y - y.mean(0, keepdim=True)
        y = (y ** 2).mean(0, keepdim=True)
        y = (y + 1e-8) ** 0.5
        y = y.mean([3, 4, 5], keepdim=True).squeeze(3)  # don't keep the meaned-out channels
        y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w)
        z = torch.cat([x, y], dim=1)
        return z


class DiscriminatorBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel):
        super().__init__(OrderedDict([
            ('conv0', EqualizedConv2d(in_channels, in_channels, kernel_size=3,
                                      gain=gain, use_wscale=use_wscale)),  # out channels nf(res-1)
            ('act0', activation_layer),
            ('blur', BlurLayer(kernel=blur_kernel)),
            ('conv1_down', EqualizedConv2d(in_channels, out_channels, kernel_size=3,
                                           gain=gain, use_wscale=use_wscale, downscale=True)),
            ('act1', activation_layer)]))
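# Illustrative shape sketch (added; not part of the original module):
# StddevLayer appends `num_new_features` statistics channels, and
# DiscriminatorBlock halves the spatial resolution while remapping channels.
def _block_shape_demo():
    x = torch.randn(4, 16, 8, 8)
    assert StddevLayer(group_size=4)(x).shape == (4, 17, 8, 8)
    block = DiscriminatorBlock(16, 32, gain=np.sqrt(2), use_wscale=True,
                               activation_layer=nn.LeakyReLU(0.2),
                               blur_kernel=[1, 2, 1])
    assert block(x).shape == (4, 32, 4, 4)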
class DiscriminatorTop(nn.Sequential):
    def __init__(self,
                 mbstd_group_size,
                 mbstd_num_features,
                 in_channels,
                 intermediate_channels,
                 gain,
                 use_wscale,
                 activation_layer,
                 resolution=4,
                 in_channels2=None,
                 output_features=1,
                 last_gain=1):
        """
        :param mbstd_group_size: Group size for the minibatch standard deviation layer, <= 1 = disable.
        :param mbstd_num_features: Number of features for the minibatch standard deviation layer.
        :param in_channels: Number of input channels.
        :param intermediate_channels: Number of output features of the first dense layer.
        :param gain: Gain used by the equalized layers.
        :param use_wscale: Enable equalized learning rate?
        :param activation_layer: Activation module applied after each parameterized layer.
        :param resolution: Spatial resolution of the incoming feature map.
        :param in_channels2: Number of output channels of the conv layer; defaults to in_channels.
        :param output_features: Number of output scores.
        :param last_gain: Gain used for the final dense layer.
        """
        layers = []
        if mbstd_group_size > 1:
            layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))
        else:
            # Without the stddev layer no extra channels are appended, so the
            # conv below must not expect them.
            mbstd_num_features = 0
        if in_channels2 is None:
            in_channels2 = in_channels
        layers.append(('conv', EqualizedConv2d(in_channels + mbstd_num_features, in_channels2, kernel_size=3,
                                               gain=gain, use_wscale=use_wscale)))
        layers.append(('act0', activation_layer))
        layers.append(('view', View(-1)))
        layers.append(('dense0', EqualizedLinear(in_channels2 * resolution * resolution, intermediate_channels,
                                                 gain=gain, use_wscale=use_wscale)))
        layers.append(('act1', activation_layer))
        layers.append(('dense1', EqualizedLinear(intermediate_channels, output_features,
                                                 gain=last_gain, use_wscale=use_wscale)))
        super().__init__(OrderedDict(layers))
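# Illustrative shape sketch (added; not part of the original module):
# DiscriminatorTop maps the final 4x4 feature map to per-sample scores. The
# batch size should be divisible by mbstd_group_size for the reshape inside
# StddevLayer.
def _top_shape_demo():
    top = DiscriminatorTop(mbstd_group_size=4, mbstd_num_features=1,
                           in_channels=512, intermediate_channels=512,
                           gain=np.sqrt(2), use_wscale=True,
                           activation_layer=nn.LeakyReLU(0.2))
    x = torch.randn(4, 512, 4, 4)
    assert top(x).shape == (4, 1)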
class StyleGanDiscriminator(nn.Module):
    def __init__(self, resolution, num_channels=3, fmap_base=8192, fmap_decay=1.0, fmap_max=512,
                 nonlinearity='lrelu', use_wscale=True, mbstd_group_size=4, mbstd_num_features=1,
                 blur_filter=None, structure='fixed', **kwargs):
        """
        Discriminator used in the StyleGAN paper.

        :param resolution: Input resolution. Overridden based on dataset.
        :param num_channels: Number of input color channels. Overridden based on dataset.
        :param fmap_base: Overall multiplier for the number of feature maps.
        :param fmap_decay: log2 feature map reduction when doubling the resolution.
        :param fmap_max: Maximum number of feature maps in any layer.
        :param nonlinearity: Activation function: 'relu', 'lrelu'.
        :param use_wscale: Enable equalized learning rate?
        :param mbstd_group_size: Group size for the minibatch standard deviation layer, 0 = disable.
        :param mbstd_num_features: Number of features for the minibatch standard deviation layer.
        :param blur_filter: Low-pass filter to apply when resampling activations.
            None = BlurLayer's default [1, 2, 1] kernel.
        :param structure: 'fixed' = no progressive growing,
            'linear' = progressive growing with linear fade-in.
        :param kwargs: Ignore unrecognized keyword args.
        """
        super(StyleGanDiscriminator, self).__init__()

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.mbstd_num_features = mbstd_num_features
        self.mbstd_group_size = mbstd_group_size
        self.structure = structure

        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2 ** resolution_log2 and resolution >= 4
        self.depth = resolution_log2 - 1

        # nn.ReLU() instead of torch.relu: nn.Sequential only accepts modules.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]

        # Create the downscaling blocks and the fromRGB layers for the
        # various input resolutions.
        blocks = []
        from_rgb = []
        for res in range(resolution_log2, 2, -1):
            # name = '{s}x{s}'.format(s=2 ** res)
            blocks.append(DiscriminatorBlock(nf(res - 1), nf(res - 2),
                                             gain=gain, use_wscale=use_wscale,
                                             activation_layer=act, blur_kernel=blur_filter))
            from_rgb.append(EqualizedConv2d(num_channels, nf(res - 1), kernel_size=1,
                                            gain=gain, use_wscale=use_wscale))
        self.blocks = nn.ModuleList(blocks)

        # Build the final block.
        self.final_block = DiscriminatorTop(self.mbstd_group_size, self.mbstd_num_features,
                                            in_channels=nf(2), intermediate_channels=nf(2),
                                            gain=gain, use_wscale=use_wscale, activation_layer=act)
        from_rgb.append(EqualizedConv2d(num_channels, nf(2), kernel_size=1,
                                        gain=gain, use_wscale=use_wscale))
        self.from_rgb = nn.ModuleList(from_rgb)

        # Register the temporary downsampler used for the fade-in residual path.
        self.temporaryDownsampler = nn.AvgPool2d(2)

    def forward(self, images_in, depth=0, alpha=1.):
        """
        :param images_in: Input images [mini_batch, channel, height, width].
        :param depth: Current height of operation (progressive growing); the
            input resolution at a given depth is 4 * 2 ** depth.
        :param alpha: Current value of alpha for the fade-in.
        :return: scores_out
        """
        if self.structure == 'fixed':
            x = self.from_rgb[0](images_in)
            for block in self.blocks:
                x = block(x)
            scores_out = self.final_block(x)
        elif self.structure == 'linear':
            assert depth < self.depth, "Requested output depth cannot be produced"
            if depth > 0:
                # Fade in the newest block: blend its output ("straight") with
                # a downsampled fromRGB shortcut ("residual").
                residual = self.from_rgb[self.depth - depth](self.temporaryDownsampler(images_in))
                straight = self.blocks[self.depth - depth - 1](self.from_rgb[self.depth - depth - 1](images_in))
                x = (alpha * straight) + ((1 - alpha) * residual)
                for block in self.blocks[(self.depth - depth):]:
                    x = block(x)
            else:
                x = self.from_rgb[-1](images_in)
            scores_out = self.final_block(x)
        else:
            raise KeyError(f"Unknown structure: {self.structure}")
        return scores_out
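# Minimal smoke test (added for illustration; the shapes assume the default
# hyper-parameters above). 'fixed' runs the full stack; 'linear' fades in the
# block at `depth`, where the input resolution is 4 * 2 ** depth.
if __name__ == "__main__":
    disc = StyleGanDiscriminator(resolution=64, structure='fixed')
    print(disc(torch.randn(4, 3, 64, 64)).shape)  # torch.Size([4, 1])

    disc = StyleGanDiscriminator(resolution=64, structure='linear')
    print(disc(torch.randn(4, 3, 8, 8), depth=1, alpha=0.5).shape)  # torch.Size([4, 1])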