Fixes for unet
This commit is contained in:
parent
89f56b2091
commit
ea94b93a37
|
@ -4,7 +4,7 @@ from math import log2
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.archs.stylegan.stylegan2 import attn_and_ff, Flatten
|
||||
from models.archs.stylegan.stylegan2 import attn_and_ff
|
||||
|
||||
|
||||
def leaky_relu(p=0.2):
|
||||
|
@ -20,6 +20,14 @@ def double_conv(chan_in, chan_out):
|
|||
)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def __init__(self, index):
|
||||
super().__init__()
|
||||
self.index = index
|
||||
def forward(self, x):
|
||||
return x.flatten(self.index)
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
def __init__(self, input_channels, filters, downsample=True):
|
||||
super().__init__()
|
||||
|
@ -103,7 +111,7 @@ class StyleGan2UnetDiscriminator(nn.Module):
|
|||
|
||||
dec_chan_in_out = chan_in_out[:-1][::-1]
|
||||
self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
|
||||
self.conv_out = nn.Conv2d(3, 1, 1)
|
||||
self.conv_out = nn.Conv2d(input_filters, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, *_ = x.shape
|
||||
|
@ -124,4 +132,4 @@ class StyleGan2UnetDiscriminator(nn.Module):
|
|||
x = up_block(x, res)
|
||||
|
||||
dec_out = self.conv_out(x)
|
||||
return enc_out.squeeze(), dec_out
|
||||
return dec_out
|
||||
|
|
Loading…
Reference in New Issue
Block a user