Fixes for unet

This commit is contained in:
James Betker 2020-11-15 10:38:33 -07:00
parent 89f56b2091
commit ea94b93a37

View File

@ -4,7 +4,7 @@ from math import log2
import torch
import torch.nn as nn
from models.archs.stylegan.stylegan2 import attn_and_ff, Flatten
from models.archs.stylegan.stylegan2 import attn_and_ff
def leaky_relu(p=0.2):
@ -20,6 +20,14 @@ def double_conv(chan_in, chan_out):
)
class Flatten(nn.Module):
def __init__(self, index):
super().__init__()
self.index = index
def forward(self, x):
return x.flatten(self.index)
class DownBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
super().__init__()
@ -103,7 +111,7 @@ class StyleGan2UnetDiscriminator(nn.Module):
dec_chan_in_out = chan_in_out[:-1][::-1]
self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
self.conv_out = nn.Conv2d(3, 1, 1)
self.conv_out = nn.Conv2d(input_filters, 1, 1)
def forward(self, x):
b, *_ = x.shape
@ -124,4 +132,4 @@ class StyleGan2UnetDiscriminator(nn.Module):
x = up_block(x, res)
dec_out = self.conv_out(x)
return enc_out.squeeze(), dec_out
return dec_out