diff --git a/codes/models/archs/stylegan/stylegan2_unet_disc.py b/codes/models/archs/stylegan/stylegan2_unet_disc.py index 2f92750b..f506b83e 100644 --- a/codes/models/archs/stylegan/stylegan2_unet_disc.py +++ b/codes/models/archs/stylegan/stylegan2_unet_disc.py @@ -4,7 +4,7 @@ from math import log2 import torch import torch.nn as nn -from models.archs.stylegan.stylegan2 import attn_and_ff, Flatten +from models.archs.stylegan.stylegan2 import attn_and_ff def leaky_relu(p=0.2): @@ -20,6 +20,14 @@ def double_conv(chan_in, chan_out): ) +class Flatten(nn.Module): + def __init__(self, index): + super().__init__() + self.index = index + def forward(self, x): + return x.flatten(self.index) + + class DownBlock(nn.Module): def __init__(self, input_channels, filters, downsample=True): super().__init__() @@ -103,7 +111,7 @@ class StyleGan2UnetDiscriminator(nn.Module): dec_chan_in_out = chan_in_out[:-1][::-1] self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out))) - self.conv_out = nn.Conv2d(3, 1, 1) + self.conv_out = nn.Conv2d(input_filters, 1, 1) def forward(self, x): b, *_ = x.shape @@ -124,4 +132,4 @@ class StyleGan2UnetDiscriminator(nn.Module): x = up_block(x, res) dec_out = self.conv_out(x) - return enc_out.squeeze(), dec_out \ No newline at end of file + return dec_out