forked from mrq/DL-Art-School
Fixes for unet
This commit is contained in:
parent
89f56b2091
commit
ea94b93a37
|
@ -4,7 +4,7 @@ from math import log2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from models.archs.stylegan.stylegan2 import attn_and_ff, Flatten
|
from models.archs.stylegan.stylegan2 import attn_and_ff
|
||||||
|
|
||||||
|
|
||||||
def leaky_relu(p=0.2):
|
def leaky_relu(p=0.2):
|
||||||
|
@ -20,6 +20,14 @@ def double_conv(chan_in, chan_out):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
def __init__(self, index):
|
||||||
|
super().__init__()
|
||||||
|
self.index = index
|
||||||
|
def forward(self, x):
|
||||||
|
return x.flatten(self.index)
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
def __init__(self, input_channels, filters, downsample=True):
|
def __init__(self, input_channels, filters, downsample=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -103,7 +111,7 @@ class StyleGan2UnetDiscriminator(nn.Module):
|
||||||
|
|
||||||
dec_chan_in_out = chan_in_out[:-1][::-1]
|
dec_chan_in_out = chan_in_out[:-1][::-1]
|
||||||
self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
|
self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
|
||||||
self.conv_out = nn.Conv2d(3, 1, 1)
|
self.conv_out = nn.Conv2d(input_filters, 1, 1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
b, *_ = x.shape
|
b, *_ = x.shape
|
||||||
|
@ -124,4 +132,4 @@ class StyleGan2UnetDiscriminator(nn.Module):
|
||||||
x = up_block(x, res)
|
x = up_block(x, res)
|
||||||
|
|
||||||
dec_out = self.conv_out(x)
|
dec_out = self.conv_out(x)
|
||||||
return enc_out.squeeze(), dec_out
|
return dec_out
|
||||||
|
|
Loading…
Reference in New Issue
Block a user