diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py
index d2e39331..f44f1ddf 100644
--- a/codes/data/stylegan2_dataset.py
+++ b/codes/data/stylegan2_dataset.py
@@ -9,7 +9,7 @@ from torchvision import transforms
 import torch.nn as nn
 from pathlib import Path
 
-from models.archs.stylegan2 import exists
+from models.archs.stylegan.stylegan2 import exists
 
 
 def convert_transparent_to_rgb(image):
diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan/stylegan2.py
similarity index 100%
rename from codes/models/archs/stylegan2.py
rename to codes/models/archs/stylegan/stylegan2.py
diff --git a/codes/models/archs/stylegan/stylegan2_unet_disc.py b/codes/models/archs/stylegan/stylegan2_unet_disc.py
new file mode 100644
index 00000000..fd51e1ea
--- /dev/null
+++ b/codes/models/archs/stylegan/stylegan2_unet_disc.py
@@ -0,0 +1,128 @@
+from functools import partial
+from math import log2
+
+import torch
+import torch.nn as nn
+
+# attn_and_ff is assumed to be exported by the relocated lucidrains-style
+# stylegan2 module; it was referenced here without an import.
+from models.archs.stylegan.stylegan2 import attn_and_ff
+
+
+def leaky_relu(p=0.2):
+    return nn.LeakyReLU(p)
+
+
+def double_conv(chan_in, chan_out):
+    return nn.Sequential(
+        nn.Conv2d(chan_in, chan_out, 3, padding=1),
+        leaky_relu(),
+        nn.Conv2d(chan_out, chan_out, 3, padding=1),
+        leaky_relu()
+    )
+
+
+class DownBlock(nn.Module):
+    def __init__(self, input_channels, filters, downsample=True):
+        super().__init__()
+        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
+
+        self.net = double_conv(input_channels, filters)
+        self.down = nn.Conv2d(filters, filters, 3, padding=1, stride=2) if downsample else None
+
+    def forward(self, x):
+        res = self.conv_res(x)
+        x = self.net(x)
+        unet_res = x
+
+        if self.down is not None:
+            x = self.down(x)
+
+        x = x + res
+        return x, unet_res
+
+
+class UpBlock(nn.Module):
+    def __init__(self, input_channels, filters):
+        super().__init__()
+        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride=2)
+        self.net = double_conv(input_channels, filters)
+        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.input_channels = input_channels
+        self.filters = filters
+
+    def forward(self, x, res):
+        *_, h, w = x.shape
+        conv_res = self.conv_res(x, output_size=(h * 2, w * 2))
+        x = self.up(x)
+        x = torch.cat((x, res), dim=1)
+        x = self.net(x)
+        x = x + conv_res
+        return x
+
+
+class StyleGan2UnetDiscriminator(nn.Module):
+    def __init__(self, image_size, network_capacity=16, fmap_max=512, input_filters=3):
+        super().__init__()
+        num_layers = int(log2(image_size) - 3)
+
+        filters = [input_filters] + [network_capacity * (2 ** i) for i in range(num_layers + 1)]
+
+        set_fmap_max = partial(min, fmap_max)
+        filters = list(map(set_fmap_max, filters))
+        filters[-1] = filters[-2]
+
+        chan_in_out = list(zip(filters[:-1], filters[1:]))
+        chan_in_out = list(map(list, chan_in_out))
+
+        down_blocks = []
+        attn_blocks = []
+
+        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
+            is_not_last = ind != (len(chan_in_out) - 1)
+
+            block = DownBlock(in_chan, out_chan, downsample=is_not_last)
+            down_blocks.append(block)
+
+            attn_fn = attn_and_ff(out_chan)
+            attn_blocks.append(attn_fn)
+
+        self.down_blocks = nn.ModuleList(down_blocks)
+        self.attn_blocks = nn.ModuleList(attn_blocks)
+
+        last_chan = filters[-1]
+
+        self.to_logit = nn.Sequential(
+            leaky_relu(),
+            nn.AvgPool2d(image_size // (2 ** num_layers)),
+            nn.Flatten(1),
+            nn.Linear(last_chan, 1)
+        )
+
+        self.conv = double_conv(last_chan, last_chan)
+
+        dec_chan_in_out = chan_in_out[:-1][::-1]
+        self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
+        # The final up block restores `input_filters` channels, so the output
+        # conv must match it rather than hard-coding 3.
+        self.conv_out = nn.Conv2d(input_filters, 1, 1)
+
+    def forward(self, x):
+        b, *_ = x.shape
+
+        residuals = []
+
+        for (down_block, attn_block) in zip(self.down_blocks, self.attn_blocks):
+            x, unet_res = down_block(x)
+            residuals.append(unet_res)
+
+            if attn_block is not None:
+                x = attn_block(x)
+
+        x = self.conv(x) + x
+        enc_out = self.to_logit(x)
+
+        for (up_block, res) in zip(self.up_blocks, residuals[:-1][::-1]):
+            x = up_block(x, res)
+
+        dec_out = self.conv_out(x)
+        return enc_out.squeeze(), dec_out
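
Note (not part of the diff): the new discriminator returns a pair of outputs, a per-image logit from the encoder head (`to_logit`) and a per-pixel logit map from the U-Net decoder (`conv_out`). A minimal smoke test of that shape contract, assuming `codes/` is on `PYTHONPATH` and that `attn_and_ff` from the relocated `stylegan2` module is shape-preserving:

```python
import torch

from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDiscriminator

# image_size=64 -> int(log2(64) - 3) = 3 downsampling stages, so the encoder
# pools an 8x8 feature map into one logit per image while the decoder
# upsamples back to full resolution for per-pixel logits.
disc = StyleGan2UnetDiscriminator(image_size=64, network_capacity=16, input_filters=3)
imgs = torch.randn(4, 3, 64, 64)

enc_out, dec_out = disc(imgs)
print(enc_out.shape)  # torch.Size([4])            - one real/fake logit per image
print(dec_out.shape)  # torch.Size([4, 1, 64, 64]) - one logit per pixel
```
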
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 0d5296b0..95ac3ae1 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -22,7 +22,8 @@ from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
 from models.archs.pyramid_arch import BasicResamplingFlowNet
 from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
 from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
-from models.archs.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor
+from models.archs.stylegan.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor
+from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDiscriminator
 from models.archs.teco_resgen import TecoGen
 
 logger = logging.getLogger('base')
@@ -200,6 +201,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
         disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
         netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
+    elif which_model == "stylegan2_unet":
+        disc = StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
+        netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
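
With this registration, the new discriminator is selectable from an experiment config via the option keys `define_D_net` reads. A hypothetical sketch: the `"stylegan2_unet"` value and the `image_size`/`in_nc`/`augmentation_*` keys are visible in the branch above, but the `which_model_D` key name and the augmentation values are assumptions based on the surrounding MMSR-style option handling:

```python
from models.networks import define_D_net

# Illustrative values only; in practice these come from the experiment's
# options file.
opt_net = {
    'which_model_D': 'stylegan2_unet',      # assumed key backing `which_model`
    'image_size': 64,
    'in_nc': 3,
    'augmentation_types': ['translation'],  # must be types StyleGan2Augmentor supports
    'augmentation_probability': 0.25,
}
netD = define_D_net(opt_net, img_sz=64)     # StyleGan2Augmentor wrapping the U-Net disc
```
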
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 28f692b7..639672cf 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -517,7 +517,7 @@
         # Apply gradient penalty. TODO: migrate this elsewhere.
         if self.env['step'] % self.gp_frequency == 0:
-            from models.archs.stylegan2 import gradient_penalty
+            from models.archs.stylegan.stylegan2 import gradient_penalty
             gp = gradient_penalty(real_input, real)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))
             divergence_loss = divergence_loss + gp
 
@@ -532,17 +532,17 @@ class StyleGan2PathLengthLoss(ConfigurableLoss):
         self.w_styles = opt['w_styles']
         self.gen = opt['gen']
         self.pl_mean = None
-        from models.archs.stylegan2 import EMA
+        from models.archs.stylegan.stylegan2 import EMA
         self.pl_length_ma = EMA(.99)
 
     def forward(self, net, state):
         w_styles = state[self.w_styles]
         gen = state[self.gen]
-        from models.archs.stylegan2 import calc_pl_lengths
+        from models.archs.stylegan.stylegan2 import calc_pl_lengths
         pl_lengths = calc_pl_lengths(w_styles, gen)
         avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
 
-        from models.archs.stylegan2 import is_empty
+        from models.archs.stylegan.stylegan2 import is_empty
         if not is_empty(self.pl_mean):
             pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
             if not torch.isnan(pl_loss):
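
Note on the first hunk: the penalty is only computed when `self.env['step'] % self.gp_frequency == 0`, which amortizes the cost of the extra backward pass ("lazy" regularization, as in StyleGAN2). For orientation, here is a generic sketch of the kind of R1-style penalty `gradient_penalty` computes; the authoritative version lives in `models/archs/stylegan/stylegan2.py` and may differ in details such as the target norm:

```python
import torch

def gradient_penalty_sketch(real_images, d_out, weight=10.0):
    # Illustrative only; see the repo's own gradient_penalty for the real form.
    # real_images must have requires_grad=True before the discriminator
    # forward pass that produced d_out.
    grads = torch.autograd.grad(outputs=d_out.sum(), inputs=real_images,
                                create_graph=True)[0]
    grads = grads.reshape(grads.shape[0], -1)
    # Penalize the squared gradient norm of the discriminator at real images.
    return weight * grads.norm(2, dim=1).pow(2).mean()
```
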