forked from mrq/DL-Art-School
Lots of new discriminator nets
This commit is contained in:
parent 4e5ba61ae7
commit 6a2fd5f7d0
0  codes/models/__init__.py  Normal file

@@ -1,197 +0,0 @@
import os

import torch
import torchvision
from torch import nn

from models.archs.SPSR_arch import ImageGradientNoPadding
from models.archs.arch_util import ConvGnLelu, ExpansionBlock2, ConvGnSilu, ConjoinBlock, MultiConvBlock, \
    FinalUpsampleBlock2x, ReferenceJoinBlock
from models.archs.spinenet_arch import SpineNet
from utils.util import checkpoint

class BasicEmbeddingPyramid(nn.Module):
    def __init__(self, use_norms=True):
        super(BasicEmbeddingPyramid, self).__init__()
        self.initial_process = ConvGnLelu(64, 64, kernel_size=1, bias=True, activation=True, norm=False)
        self.reducers = nn.ModuleList([ConvGnLelu(64, 128, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
                                       ConvGnLelu(128, 128, kernel_size=3, bias=False, activation=True, norm=use_norms),
                                       ConvGnLelu(128, 256, stride=2, kernel_size=1, bias=False, activation=True, norm=False),
                                       ConvGnLelu(256, 256, kernel_size=3, bias=False, activation=True, norm=use_norms)])
        self.expanders = nn.ModuleList([ExpansionBlock2(256, 128, block=ConvGnLelu),
                                        ExpansionBlock2(128, 64, block=ConvGnLelu)])
        self.embedding_processor1 = ConvGnSilu(256, 128, kernel_size=1, bias=True, activation=True, norm=False)
        self.embedding_joiner1 = ConjoinBlock(128, block=ConvGnLelu, norm=use_norms)
        self.embedding_processor2 = ConvGnSilu(256, 256, kernel_size=1, bias=True, activation=True, norm=False)
        self.embedding_joiner2 = ConjoinBlock(256, block=ConvGnLelu, norm=use_norms)

        self.final_process = nn.Sequential(ConvGnLelu(128, 96, kernel_size=1, bias=False, activation=False, norm=False,
                                                      weight_init_factor=.1),
                                           ConvGnLelu(96, 64, kernel_size=1, bias=False, activation=False, norm=False,
                                                      weight_init_factor=.1),
                                           ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
                                                      weight_init_factor=.1),
                                           ConvGnLelu(64, 64, kernel_size=1, bias=False, activation=False, norm=False,
                                                      weight_init_factor=.1))

    def forward(self, x, *embeddings):
        p = self.initial_process(x)
        identities = []
        for i in range(2):
            identities.append(p)
            p = self.reducers[i*2](p)
            p = self.reducers[i*2+1](p)
            if i == 0:
                p = self.embedding_joiner1(p, self.embedding_processor1(embeddings[0]))
            elif i == 1:
                p = self.embedding_joiner2(p, self.embedding_processor2(embeddings[1]))
        for i in range(2):
            p = self.expanders[i](p, identities[-(i+1)])
        x = self.final_process(torch.cat([x, p], dim=1))
        return x, p

class ChainedEmbeddingGenWithStructure(nn.Module):
    def __init__(self, in_nc=3, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
        super(ChainedEmbeddingGenWithStructure, self).__init__()
        self.recurrent = recurrent
        self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False)
        if recurrent:
            self.recurrent_nf = recurrent_nf
            self.recurrent_stride = recurrent_stride
            self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
            self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
        self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
        self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
        self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
        self.structure_upsample = FinalUpsampleBlock2x(64)
        self.grad_extract = ImageGradientNoPadding()
        self.upsample = FinalUpsampleBlock2x(64)
        self.ref_join_std = 0

    def forward(self, x, recurrent=None):
        fea = self.initial_conv(x)
        if self.recurrent:
            if recurrent is None:
                if self.recurrent_nf == 3:
                    recurrent = torch.zeros_like(x)
                    if self.recurrent_stride != 1:
                        recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest')
                else:
                    recurrent = torch.zeros_like(fea)
            rec = self.recurrent_process(recurrent)
            fea, recstd = self.recurrent_join(fea, rec)
            self.ref_join_std = recstd.item()
        if self.spine is not None:
            emb = checkpoint(self.spine, fea)
        else:
            b, f, h, w = fea.shape
            emb = (torch.zeros((b, f, h//2, w//2), device=fea.device),
                   torch.zeros((b, f, h//4, w//4), device=fea.device))
        grad = fea
        for i, block in enumerate(self.blocks):
            fea = fea + checkpoint(block, fea, *emb)[0]
            if i < 3:
                structure_br = checkpoint(self.structure_joins[i], grad, fea)
                grad = grad + checkpoint(self.structure_blocks[i], structure_br)
        out = checkpoint(self.upsample, fea)
        return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea

    def get_debug_values(self, step, net_name):
        return {'ref_join_std': self.ref_join_std}

# This is a structural block that learns to mute regions of a residual transformation given a signal.
class OptionalPassthroughBlock(nn.Module):
    def __init__(self, nf, initial_bias=10):
        super(OptionalPassthroughBlock, self).__init__()
        self.switch_process = nn.Sequential(ConvGnLelu(nf, nf // 2, 1, activation=False, norm=False, bias=False),
                                            ConvGnLelu(nf // 2, nf // 4, 1, activation=False, norm=False, bias=False),
                                            ConvGnLelu(nf // 4, 1, 1, activation=False, norm=False, bias=False))
        self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float), requires_grad=True)
        self.activation = nn.Sigmoid()

    def forward(self, x, switch_signal):
        switch = self.switch_process(switch_signal)
        bypass_map = self.activation(self.bias + switch)
        return x * bypass_map, bypass_map
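A quick reading of the gate above (an illustrative sketch, not part of the commit): with the default initial_bias=10 the sigmoid opens at roughly 0.99995, so the block begins life as a near-identity passthrough and has to learn to drive the logit negative before it can mute a region.

import torch
gate = torch.sigmoid(torch.tensor(10.0))       # ~0.99995: the gate starts almost fully open
x = torch.randn(1, 64, 8, 8)
print(torch.allclose(x * gate, x, atol=1e-3))  # typically True: the residual passes through unchanged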
class MultifacetedChainedEmbeddingGen(nn.Module):
    def __init__(self, depth=10, scale=2):
        super(MultifacetedChainedEmbeddingGen, self).__init__()
        assert scale == 2 or scale == 4

        self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)

        if scale == 2:
            self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
        else:
            self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=7, stride=4, norm=False, bias=True, activation=False)
        self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)

        self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
        self.prog_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)

        self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
        self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=0) for i in range(depth)])
        self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
        self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
        self.structure_upsample = FinalUpsampleBlock2x(64, scale=scale)
        self.grad_extract = ImageGradientNoPadding()
        self.upsample = FinalUpsampleBlock2x(64, scale=scale)

        self.teco_ref_std = 0
        self.prog_ref_std = 0
        self.block_residual_means = [0 for _ in range(depth)]
        self.block_residual_stds = [0 for _ in range(depth)]
        self.bypass_maps = []

    def forward(self, x, teco_recurrent=None, prog_recurrent=None):
        fea = self.initial_conv(x)

        # Integrate recurrence inputs.
        if teco_recurrent is not None:
            teco_rec = self.teco_recurrent_process(teco_recurrent)
            fea, std = self.teco_recurrent_join(fea, teco_rec)
            self.teco_ref_std = std.item()
        elif prog_recurrent is not None:
            prog_rec = self.prog_recurrent_process(prog_recurrent)
            # Note: unlike the teco branch above, the joined result is assigned to prog_rec, so fea is left untouched here.
            prog_rec, std = self.prog_recurrent_join(fea, prog_rec)
            self.prog_ref_std = std.item()

        emb = checkpoint(self.spine, fea)
        grad = fea
        self.bypass_maps = []
        for i, block in enumerate(self.blocks):
            residual, context = checkpoint(block, fea, *emb)
            residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
            fea = fea + residual
            self.bypass_maps.append(bypass_map.detach())
            self.block_residual_means[i] = residual.mean().item()
            self.block_residual_stds[i] = residual.std().item()
            if i < 3:
                structure_br = checkpoint(self.structure_joins[i], grad, fea)
                grad = grad + checkpoint(self.structure_blocks[i], structure_br)
        out = checkpoint(self.upsample, fea)
        return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea

    def visual_dbg(self, step, path):
        for i, bm in enumerate(self.bypass_maps):
            torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))

    def get_debug_values(self, step, net_name):
        biases = [b.bias.item() for b in self.bypasses]
        blk_stds, blk_means = {}, {}
        for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
            blk_stds['block_%i' % (i+1,)] = s
            blk_means['block_%i' % (i+1,)] = m
        return {'teco_std': self.teco_ref_std,
                'prog_std': self.prog_ref_std,
                'bypass_biases': sum(biases) / len(biases),
                'blocks_std': blk_stds, 'blocks_mean': blk_means}

@@ -1,225 +0,0 @@
import torch
import torch.nn as nn
import numpy as np


__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv5x5(in_planes, out_planes, stride=1):
    """5x5 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
                     padding=2, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class FixupBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_bn=False, conv_create=conv3x3):
        super(FixupBasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv_create(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv_create(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.downsample = downsample
        self.stride = stride
        self.use_bn = use_bn
        if use_bn:
            self.bn1 = nn.BatchNorm2d(planes)
            self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        if self.use_bn:
            out = self.bn1(out)
        out = self.lrelu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        if self.use_bn:
            out = self.bn2(out)
        out = out * self.scale + self.bias2b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)

        out += identity
        out = self.lrelu(out)

        return out

class FixupBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(FixupBottleneck, self).__init__()
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv1x1(inplanes, planes)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes, stride)
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.bias3a = nn.Parameter(torch.zeros(1))
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias3b = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        out = self.lrelu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        out = self.lrelu(out + self.bias2b)

        out = self.conv3(out + self.bias3a)
        out = out * self.scale + self.bias3b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)

        out += identity
        out = self.lrelu(out)

        return out

class FixupResNet(nn.Module):

    def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False,
                 disable_passthrough=False):
        super(FixupResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = 3
        self.number_skips = number_skips
        self.disable_passthrough = disable_passthrough
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
        if number_skips > 0:
            self.inplanes = self.inplanes + 3  # Accommodate a skip connection from the generator.
        self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
        if number_skips > 1:
            self.inplanes = self.inplanes + 3  # Accommodate a second skip connection from the generator.
        self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
        # SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
        # Therefore, level off the filter count from this block forwards.
        self.layer3 = self._make_layer(block, num_filters*8, layers[3], stride=2, use_bn=use_bn)
        self.layer4 = self._make_layer(block, num_filters*8, layers[4], stride=2, use_bn=use_bn)
        self.bias2 = nn.Parameter(torch.zeros(1))
        reduced_img_sz = int(input_img_size / 32)
        self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
        self.fc2 = nn.Linear(100, num_classes)

        for m in self.modules():
            if isinstance(m, FixupBasicBlock):
                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
                nn.init.constant_(m.conv2.weight, 0)
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
            elif isinstance(m, FixupBottleneck):
                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
                nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
                nn.init.constant_(m.conv3.weight, 0)
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))

    def _make_layer(self, block, outplanes, blocks, stride=1, use_bn=False, conv_type=conv3x3):
        layers = []
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, self.inplanes))

        downsample = None
        if stride != 1 or self.inplanes != outplanes * block.expansion:
            downsample = conv1x1(self.inplanes, outplanes * block.expansion, stride)
        layers.append(block(self.inplanes, outplanes, stride, downsample, use_bn=use_bn, conv_create=conv_type))
        self.inplanes = outplanes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        if len(x) == 3:
            # This class can take a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
            x, med_skip, lo_skip = x
        else:
            # Or just a tuple with only the high res input (this assumes number_skips was set right).
            x = x[0]

        if self.disable_passthrough:
            if self.number_skips > 0:
                med_skip = torch.zeros_like(med_skip)
            if self.number_skips > 1:
                lo_skip = torch.zeros_like(lo_skip)
        x = self.layer0(x)
        if self.number_skips > 0:
            x = torch.cat([x, med_skip], dim=1)
        x = self.layer1(x)
        if self.number_skips > 1:
            x = torch.cat([x, lo_skip], dim=1)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = x.view(x.size(0), -1)
        x = self.lrelu(self.fc1(x))
        x = self.fc2(x + self.bias2)

        return x


def fixup_resnet18(**kwargs):
    """Constructs a Fixup-ResNet-18 model.
    """
    model = FixupResNet(FixupBasicBlock, [2, 2, 2, 2, 2], **kwargs)
    return model


def fixup_resnet34(**kwargs):
    """Constructs a Fixup-ResNet-34 model.
    """
    model = FixupResNet(FixupBasicBlock, [5, 5, 3, 3, 3], **kwargs)
    return model


def fixup_resnet50(**kwargs):
    """Constructs a Fixup-ResNet-50 model.
    """
    model = FixupResNet(FixupBottleneck, [3, 4, 6, 3, 2], **kwargs)
    return model


def fixup_resnet101(**kwargs):
    """Constructs a Fixup-ResNet-101 model.
    """
    model = FixupResNet(FixupBottleneck, [3, 4, 23, 3, 2], **kwargs)
    return model


def fixup_resnet152(**kwargs):
    """Constructs a Fixup-ResNet-152 model.
    """
    model = FixupResNet(FixupBottleneck, [3, 8, 36, 3, 2], **kwargs)
    return model


__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
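For readers unfamiliar with Fixup: the init loop above implements its scaling rule. The first conv of each residual branch gets He initialization shrunk by num_layers**(-1/(2m-2)), where m is the number of convs in the branch (m=2 for the basic block, giving the -0.5 exponent; m=3 for the bottleneck, giving -0.25), and the branch's last conv starts at zero so every block begins as an identity. A sketch of the scale factor, illustration only:

import numpy as np
num_layers, m = 8, 2                        # e.g. 8 residual blocks, 2 convs per branch
scale = num_layers ** (-1.0 / (2 * m - 2))  # == num_layers ** -0.5 for basic blocks
print(scale)                                # ~0.354: multiplies the He std of conv1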
139  codes/models/archs/biggan/biggan_discriminator.py  Normal file

@@ -0,0 +1,139 @@
import functools

import torch
from torch.nn import init

import models.archs.biggan.biggan_layers as layers
import torch.nn as nn

# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    arch = {}
    arch[256] = {'in_channels': [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': {2**i: 2**i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[128] = {'in_channels': [3] + [ch*item for item in [1, 2, 4, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                 'downsample': [True] * 5 + [False],
                 'resolution': [64, 32, 16, 8, 4, 4],
                 'attention': {2**i: 2**i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[64] = {'in_channels': [3] + [ch*item for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': {2**i: 2**i in [int(item) for item in attention.split('_')]
                              for i in range(2, 7)}}
    arch[32] = {'in_channels': [3] + [item * ch for item in [4, 4, 4]],
                'out_channels': [item * ch for item in [4, 4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 16, 16, 16],
                'attention': {2**i: 2**i in [int(item) for item in attention.split('_')]
                              for i in range(2, 6)}}
    return arch
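As a reading aid (derived from the function above, not new behavior), the resolution-128 entry with ch=64 and the default D_attn='64' expands to:

arch = D_arch(ch=64, attention='64')[128]
print(arch['in_channels'])    # [3, 64, 128, 256, 512, 1024]
print(arch['out_channels'])   # [64, 128, 256, 512, 1024, 1024]
print(arch['downsample'])     # [True, True, True, True, True, False] -> 128px pooled down to 4px
print(arch['attention'][64])  # True: a self-attention block is appended at the 64px stage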
class BigGanDiscriminator(nn.Module):

    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 SN_eps=1e-12, output_dim=1, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN'):
        super(BigGanDiscriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)

        # Initialize weights
        if not skip_init:
            self.init_weights()

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print('Param count for D''s initialized parameters: %d' % self.param_count)

    def forward(self, x, y=None):
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
        return out
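A minimal smoke test for the class above (a sketch; it assumes this commit's module layout is on the import path):

import torch
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator

D = BigGanDiscriminator(D_ch=64, resolution=128)
scores = D(torch.randn(4, 3, 128, 128))
print(scores.shape)   # torch.Size([4, 1]) -- one realness logit per image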
457  codes/models/archs/biggan/biggan_layers.py  Normal file

@@ -0,0 +1,457 @@
''' Layers
    This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

# Projection of x onto y
def proj(x, y):
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())


# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    for y in ys:
        x = x - proj(x, y)
    return x


# Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # Run one step of the power iteration
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            # Add to the list
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            # Add to the list
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value and add it to the list
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
        # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
    return svs, us, vs
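A sanity-check sketch for power_iteration (assumes it is run in this module's namespace, on a reasonably recent PyTorch): after enough iterations the first estimated singular value should approach the exact one.

import torch

W = torch.randn(16, 32)
u = [torch.randn(1, 16)]              # one left-singular-vector estimate
for _ in range(50):                   # repeated steps converge toward the top singular value
    svs, us, vs = power_iteration(W, u, update=True)
print(svs[0].item(), torch.linalg.svdvals(W)[0].item())   # the two values should nearly agree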
# Convenience passthrough function
class identity(nn.Module):
    def forward(self, input):
        return input


# Spectral normalization base class
class SN(object):
    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    # Singular vectors (u side)
    @property
    def u(self):
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    # Singular values;
    # note that these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    # Compute the spectrally-normalized weight
    def W_(self):
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        # Update the svs
        if self.training:
            with torch.no_grad():  # Make sure to do this in a no_grad() context or you'll get memory leaks!
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]

# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(x, self.W_(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                              max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())
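A quick check sketch for the SN wrappers (same namespace assumption): after a few training-mode forward passes, the normalized weight returned by W_() has a top singular value near 1.

import torch

sn = SNLinear(32, 16)
sn.train()
for _ in range(20):
    _ = sn(torch.randn(4, 32))             # each call runs one power-iteration update
print(torch.linalg.svdvals(sn.W_())[0])    # ~1.0 once the u/v estimates settle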
# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable gain parameter
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
        return self.gamma * o + x

# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    # Apply scale and shift--if gain and bias are provided, fuse them here
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = (m2 - m ** 2)
    # Cast back to float 16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)
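Why the fused form above is valid: x * scale - (mean * scale - bias) expands to (x - mean) * rsqrt(var + eps) * gain + bias, i.e. the textbook batchnorm affine. A numeric check sketch (same namespace assumption):

import torch

x = torch.randn(2, 4, 8, 8)
mean, var = torch.zeros(1), torch.ones(1)
gain, bias = torch.full((1,), 2.0), torch.full((1,), 0.5)
fused = fused_bn(x, mean, var, gain, bias)
unfused = ((x - mean) / ((var + 1e-5) ** 0.5)) * gain + bias
print(torch.allclose(fused, unfused, atol=1e-5))   # True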
# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # If not in training mode, use the stored statistics
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)

# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)

# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn', ):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica or self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        else:
            if self.norm_style == 'bn':
                out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                                   self.training, 0.1, self.eps)
            elif self.norm_style == 'in':
                out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                      self.training, 0.1, self.eps)
            elif self.norm_style == 'gn':
                out = groupnorm(x, self.norm_style)
            elif self.norm_style == 'nonorm':
                out = x
            return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)

# Normal, non-class-conditional BN
class bn(nn.Module):
    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica or mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        else:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)

# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x

# Residual block for the discriminator
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None, ):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = True if (in_channels != out_channels) or downsample else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # h = self.activation(x) # NOT TODAY SATAN
            # Andy's note: This line *must* be an out-of-place ReLU or it
            # will negatively affect the shortcut connection.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)
0  codes/models/archs/fixup_resnet/__init__.py  Normal file
375  codes/models/archs/stylegan/Discriminator_StyleGAN.py  Normal file

@@ -0,0 +1,375 @@
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class BlurLayer(nn.Module):
    def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        if kernel is None:
            kernel = [1, 2, 1]
        kernel = torch.tensor(kernel, dtype=torch.float32)
        kernel = kernel[:, None] * kernel[None, :]
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # Note: torch tensors do not support negative-step slicing; flip=True would need torch.flip.
            kernel = kernel[:, :, ::-1, ::-1]
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # expand kernel channels
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=int((self.kernel.size(2) - 1) / 2),
            groups=x.size(1)
        )
        return x

class Upscale2d(nn.Module):
    @staticmethod
    def upscale2d(x, factor=2, gain=1):
        assert x.dim() == 4
        if gain != 1:
            x = x * gain
        if factor != 1:
            shape = x.shape
            x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
            x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
        return x

    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return self.upscale2d(x, factor=self.factor, gain=self.gain)

class Downscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.factor = factor
        self.gain = gain
        if factor == 2:
            f = [np.sqrt(gain) / factor] * factor
            self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)
        else:
            self.blur = None

    def forward(self, x):
        assert x.dim() == 4
        # 2x2, float32 => downscale using _blur2d().
        if self.blur is not None and x.dtype == torch.float32:
            return self.blur(x)

        # Apply gain.
        if self.gain != 1:
            x = x * self.gain

        # No-op => early exit.
        if self.factor == 1:
            return x

        # Large factor => downscale using tf.nn.avg_pool().
        # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
        return F.avg_pool2d(x, self.factor)

class EqualizedConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** 0.5, use_wscale=False,
                 lrmul=1, bias=True, intermediate=None, upscale=False, downscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        if downscale:
            self.downscale = Downscale2d()
        else:
            self.downscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(
            torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
            # this really needs to be cleaned up and go into the conv...
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
            w = F.pad(w, [1, 1, 1, 1])
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        downscale = self.downscale
        intermediate = self.intermediate
        if downscale is not None and min(x.shape[2:]) >= 128:
            w = self.weight * self.w_mul
            w = F.pad(w, [1, 1, 1, 1])
            # in contrast to upscale, this is a mean...
            w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25  # avg_pool?
            x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
            downscale = None
        elif downscale is not None:
            assert intermediate is None
            intermediate = downscale

        if not have_convolution and intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)

        if intermediate is not None:
            x = intermediate(x)

        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x

class EqualizedLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size ** (-0.5)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)
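The equalized-learning-rate trick above, in brief: with use_wscale=True the parameter is stored with roughly unit std and multiplied by w_mul = he_std * lrmul at call time, so the effective weights carry He scale while the optimizer sees unit-scale parameters. A sketch (same namespace assumption):

import torch

lin = EqualizedLinear(256, 128, use_wscale=True)
print(lin.weight.std().item())                # ~1.0: the stored parameter scale
print((lin.weight * lin.w_mul).std().item())  # ~(2/256)**0.5 ~= 0.088: the effective He std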
class View(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)

class StddevLayer(nn.Module):
    def __init__(self, group_size=4, num_new_features=1):
        super().__init__()
        self.group_size = group_size
        self.num_new_features = num_new_features

    def forward(self, x):
        b, c, h, w = x.shape
        group_size = min(self.group_size, b)
        y = x.reshape([group_size, -1, self.num_new_features,
                       c // self.num_new_features, h, w])
        y = y - y.mean(0, keepdim=True)
        y = (y ** 2).mean(0, keepdim=True)
        y = (y + 1e-8) ** 0.5
        y = y.mean([3, 4, 5], keepdim=True).squeeze(3)  # don't keep the meaned-out channels
        y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w)
        z = torch.cat([x, y], dim=1)
        return z
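A shape sketch for the minibatch-stddev layer above (same namespace assumption): it appends num_new_features channels holding the cross-sample feature stddev, which lets the discriminator sense low-variety batches.

import torch

layer = StddevLayer(group_size=4, num_new_features=1)
out = layer(torch.randn(8, 512, 4, 4))
print(out.shape)   # torch.Size([8, 513, 4, 4]) -- one stddev channel appended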
class DiscriminatorBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel):
        super().__init__(OrderedDict([
            ('conv0', EqualizedConv2d(in_channels, in_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)),
            # out channels nf(res-1)
            ('act0', activation_layer),
            ('blur', BlurLayer(kernel=blur_kernel)),
            ('conv1_down', EqualizedConv2d(in_channels, out_channels, kernel_size=3,
                                           gain=gain, use_wscale=use_wscale, downscale=True)),
            ('act1', activation_layer)]))

class DiscriminatorTop(nn.Sequential):
    def __init__(self,
                 mbstd_group_size,
                 mbstd_num_features,
                 in_channels,
                 intermediate_channels,
                 gain, use_wscale,
                 activation_layer,
                 resolution=4,
                 in_channels2=None,
                 output_features=1,
                 last_gain=1):
        """
        :param mbstd_group_size:
        :param mbstd_num_features:
        :param in_channels:
        :param intermediate_channels:
        :param gain:
        :param use_wscale:
        :param activation_layer:
        :param resolution:
        :param in_channels2:
        :param output_features:
        :param last_gain:
        """

        layers = []
        if mbstd_group_size > 1:
            layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))

        if in_channels2 is None:
            in_channels2 = in_channels

        layers.append(('conv', EqualizedConv2d(in_channels + mbstd_num_features, in_channels2, kernel_size=3,
                                               gain=gain, use_wscale=use_wscale)))
        layers.append(('act0', activation_layer))
        layers.append(('view', View(-1)))
        layers.append(('dense0', EqualizedLinear(in_channels2 * resolution * resolution, intermediate_channels,
                                                 gain=gain, use_wscale=use_wscale)))
        layers.append(('act1', activation_layer))
        layers.append(('dense1', EqualizedLinear(intermediate_channels, output_features,
                                                 gain=last_gain, use_wscale=use_wscale)))

        super().__init__(OrderedDict(layers))

class StyleGanDiscriminator(nn.Module):
|
||||
def __init__(self, resolution, num_channels=3, fmap_base=8192, fmap_decay=1.0, fmap_max=512,
|
||||
nonlinearity='lrelu', use_wscale=True, mbstd_group_size=4, mbstd_num_features=1,
|
||||
blur_filter=None, structure='fixed', **kwargs):
|
||||
"""
|
||||
Discriminator used in the StyleGAN paper.
|
||||
:param num_channels: Number of input color channels. Overridden based on dataset.
|
||||
:param resolution: Input resolution. Overridden based on dataset.
|
||||
# label_size=0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
|
||||
:param fmap_base: Overall multiplier for the number of feature maps.
|
||||
:param fmap_decay: log2 feature map reduction when doubling the resolution.
|
||||
:param fmap_max: Maximum number of feature maps in any layer.
|
||||
:param nonlinearity: Activation function: 'relu', 'lrelu'
|
||||
:param use_wscale: Enable equalized learning rate?
|
||||
:param mbstd_group_size: Group size for the mini_batch standard deviation layer, 0 = disable.
|
||||
:param mbstd_num_features: Number of features for the mini_batch standard deviation layer.
|
||||
:param blur_filter: Low-pass filter to apply when resampling activations. None = no filtering.
|
||||
:param structure: 'fixed' = no progressive growing, 'linear' = human-readable
|
||||
:param kwargs: Ignore unrecognized keyword args.
|
||||
"""
|
||||
super(StyleGanDiscriminator, self).__init__()
|
||||
|
||||
def nf(stage):
|
||||
return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
|
||||
|
||||
self.mbstd_num_features = mbstd_num_features
|
||||
self.mbstd_group_size = mbstd_group_size
|
||||
self.structure = structure
|
||||
# if blur_filter is None:
|
||||
# blur_filter = [1, 2, 1]
|
||||
|
||||
resolution_log2 = int(np.log2(resolution))
|
||||
assert resolution == 2 ** resolution_log2 and resolution >= 4
|
||||
self.depth = resolution_log2 - 1
|
||||
|
||||
act, gain = {'relu': (torch.relu, np.sqrt(2)),
|
||||
'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
|
||||
|
||||
# create the remaining layers
|
||||
blocks = []
|
||||
from_rgb = []
|
||||
for res in range(resolution_log2, 2, -1):
|
||||
# name = '{s}x{s}'.format(s=2 ** res)
|
||||
blocks.append(DiscriminatorBlock(nf(res - 1), nf(res - 2),
|
||||
gain=gain, use_wscale=use_wscale, activation_layer=act,
|
||||
blur_kernel=blur_filter))
|
||||
# create the fromRGB layers for various inputs:
|
||||
from_rgb.append(EqualizedConv2d(num_channels, nf(res - 1), kernel_size=1,
|
||||
gain=gain, use_wscale=use_wscale))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
# Building the final block.
|
||||
self.final_block = DiscriminatorTop(self.mbstd_group_size, self.mbstd_num_features,
|
||||
in_channels=nf(2), intermediate_channels=nf(2),
|
||||
gain=gain, use_wscale=use_wscale, activation_layer=act)
|
||||
from_rgb.append(EqualizedConv2d(num_channels, nf(2), kernel_size=1,
|
||||
gain=gain, use_wscale=use_wscale))
|
||||
self.from_rgb = nn.ModuleList(from_rgb)
|
||||
|
||||
# register the temporary downSampler
|
||||
self.temporaryDownsampler = nn.AvgPool2d(2)
|
||||
|
    def forward(self, images_in, depth=0, alpha=1.):
        """
        :param images_in: Input images [mini_batch, channel, height, width].
        :param depth: Current depth of operation (progressive growing); 0 = the coarsest (4x4) level.
        :param alpha: Current value of alpha for the fade-in.
        :return: Discriminator scores.
        """
        if self.structure == 'fixed':
            x = self.from_rgb[0](images_in)
            for i, block in enumerate(self.blocks):
                x = block(x)
            scores_out = self.final_block(x)
        elif self.structure == 'linear':
            assert depth < self.depth, "Requested output depth cannot be produced"
            if depth > 0:
                # Fade-in: blend the plain downsampled input (residual) with the
                # path through the newest, partially-trained block (straight).
                residual = self.from_rgb[self.depth - depth](self.temporaryDownsampler(images_in))
                straight = self.blocks[self.depth - depth - 1](self.from_rgb[self.depth - depth - 1](images_in))
                x = (alpha * straight) + ((1 - alpha) * residual)

                for block in self.blocks[(self.depth - depth):]:
                    x = block(x)
            else:
                x = self.from_rgb[-1](images_in)

            scores_out = self.final_block(x)
        else:
            raise KeyError("Unknown structure: ", self.structure)

        return scores_out
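A quick smoke test of the two paths above (a sketch: shapes follow from the indexing in forward, where the 'linear' structure at depth d expects (2 ** (d + 2))-pixel inputs and alpha blends the newest block in):

import torch

d_fixed = StyleGanDiscriminator(resolution=128, structure='fixed')
print(d_fixed(torch.randn(4, 3, 128, 128)).shape)

d_prog = StyleGanDiscriminator(resolution=128, structure='linear')
# depth=2 -> 16px inputs; alpha=0.5 -> halfway through fading in the newest block.
print(d_prog(torch.randn(4, 3, 16, 16), depth=2, alpha=0.5).shape)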
0
codes/models/archs/stylegan/__init__.py
Normal file
codes/models/networks.py
@@ -7,8 +7,7 @@ import torch
import torchvision
from munch import munchify

import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.SPSR_arch as spsr
import models.archs.SRResNet_arch as SRResNet_arch
@@ -17,9 +16,11 @@ import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.feature_arch as feature_arch
import models.archs.panet.panet as panet
import models.archs.rcan as rcan
import models.archs.ChainedEmbeddingGen as chained
from models.archs import srg2_classic
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.archs.pyramid_arch import BasicResamplingFlowNet
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
from models.archs.teco_resgen import TecoGen
@@ -90,15 +91,6 @@ def define_G(opt, net_key='network_G', scale=None):
        netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                          multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
    elif which_model == 'chained_gen_structured':
        rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
        recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
        recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
        in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
        netG = chained.ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, in_nc=in_nc)
    elif which_model == 'multifaceted_chained':
        scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2
        netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
    elif which_model == "flownet2":
        from models.flownet2.models import FlowNet2
        ld = 'load_path' in opt_net.keys()
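The branch above pulls its options from the experiment YAML via opt_net; a minimal fragment exercising chained_gen_structured could look like the following (shown as the parsed Python dict; the 'which_model_G' key name is assumed from the surrounding file, and the values are hypothetical):

opt_net = {
    'which_model_G': 'chained_gen_structured',
    'depth': 10,            # required: read without a default
    'recurrent': True,      # optional, defaults to False
    'recurrent_nf': 3,      # optional, defaults to 3
    'recurrent_stride': 2,  # optional, defaults to 2
    'in_nc': 3,             # optional, defaults to 3
}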
@@ -125,12 +117,19 @@ def define_G(opt, net_key='network_G', scale=None):
                                  blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
                                  scale=opt_net['scale'],
                                  bottom_latent_only=opt_net['bottom_latent_only'])
    elif which_model == "adarrdb":
        netG = AdaRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                          mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                          blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
                          scale=opt_net['scale'])
    elif which_model == "latent_estimator":
        if opt_net['version'] == 2:
            netG = LatentEstimator2(in_nc=3, nf=opt_net['nf'])
        else:
            overwrite = [1, 2] if opt_net['only_base_level'] else []
            netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
    elif which_model == "linear_latent_estimator":
        netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
    return netG
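One thing to watch in the latent_estimator branch above: unlike the chained_gen options, 'version' and 'only_base_level' are read without a presence check, so the YAML must define both or define_G raises a KeyError. A hypothetical minimal opt_net:

opt_net = {
    'which_model_G': 'latent_estimator',  # key name assumed from this file
    'version': 1,             # 2 selects LatentEstimator2
    'only_base_level': True,  # True -> overwrite_levels = [1, 2]
    'nf': 64,
}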
@@ -159,19 +158,19 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
        netD = GradDiscWrapper(netD)
    elif which_model == 'discriminator_vgg_128_gn_checkpointed':
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True)
    elif which_model == 'stylegan_vgg':
        netD = StyleGanDiscriminator(128)
    elif which_model == 'discriminator_resnet':
        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_50':
        netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_passthrough':
        netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
                                                                   number_skips=opt_net['number_skips'], use_bn=True,
                                                                   disable_passthrough=opt_net['disable_passthrough'])
    elif which_model == 'resnext':
        netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
-        state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
-        netD.load_state_dict(state_dict, strict=False)
+        #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
+        #netD.load_state_dict(state_dict, strict=False)
        netD.fc = torch.nn.Linear(512 * 4, 1)
    elif which_model == 'biggan_resnet':
        netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2))
    elif which_model == 'discriminator_pix':
        netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet":
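The 'resnext' branch above repurposes a torchvision classifier as a one-logit discriminator: GroupNorm (8 groups) replaces BatchNorm via the norm_layer hook, and this commit comments out the ImageNet weight load. As a standalone sketch:

import functools
import torch
import torchvision

# GroupNorm avoids batch-statistic coupling across the (often small) GAN batch.
netD = torchvision.models.resnext50_32x4d(
    norm_layer=functools.partial(torch.nn.GroupNorm, 8))
netD.fc = torch.nn.Linear(512 * 4, 1)  # 2048-dim features -> single realism score

print(netD(torch.randn(2, 3, 128, 128)).shape)  # torch.Size([2, 1])

When the load was active, strict=False tolerated the mismatch between the BatchNorm checkpoint and the GroupNorm modules: the weight/bias tensors share shape [C] and load, while the running statistics are simply ignored.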
@@ -265,7 +265,7 @@ class Trainer:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
@@ -280,7 +280,7 @@ class Trainer:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_latent_mi1_rrdb4x_6bl_lower_signal_2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
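Note that both trainer entry points hard-code an experiment YAML as their -opt default, so launching a trainer with no arguments runs the adalatent experiments; pass -opt explicitly to train anything else.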