80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
|
||
|
def down_block(in_filters, out_filters):
|
||
|
return nn.Conv2d(in_filters, out_filters, kernel_size=5,
|
||
|
stride=2, padding=2,
|
||
|
), nn.Sequential(
|
||
|
nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
|
||
|
nn.LeakyReLU(0.2)
|
||
|
)
|
||
|
|
||
|
|
||
|
def up_block(in_filters, out_filters, dropout=False):
|
||
|
layers = [
|
||
|
nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,
|
||
|
stride=2, padding=2, output_padding=1
|
||
|
),
|
||
|
nn.ReLU(),
|
||
|
nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
|
||
|
]
|
||
|
if dropout:
|
||
|
layers.append(nn.Dropout(0.5))
|
||
|
|
||
|
return nn.Sequential(*layers)
|
||
|
|
||
|
|
||
|
class UNet(nn.Module):
|
||
|
def __init__(self, in_channels=2):
|
||
|
super(UNet, self).__init__()
|
||
|
self.down1_conv, self.down1_act = down_block(in_channels, 16)
|
||
|
self.down2_conv, self.down2_act = down_block(16, 32)
|
||
|
self.down3_conv, self.down3_act = down_block(32, 64)
|
||
|
self.down4_conv, self.down4_act = down_block(64, 128)
|
||
|
self.down5_conv, self.down5_act = down_block(128, 256)
|
||
|
self.down6_conv, self.down6_act = down_block(256, 512)
|
||
|
|
||
|
self.up1 = up_block(512, 256, dropout=True)
|
||
|
self.up2 = up_block(512, 128, dropout=True)
|
||
|
self.up3 = up_block(256, 64, dropout=True)
|
||
|
self.up4 = up_block(128, 32)
|
||
|
self.up5 = up_block(64, 16)
|
||
|
self.up6 = up_block(32, 1)
|
||
|
self.up7 = nn.Sequential(
|
||
|
nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3),
|
||
|
nn.Sigmoid()
|
||
|
)
|
||
|
|
||
|
def forward(self, x):
|
||
|
d1_conv = self.down1_conv(x)
|
||
|
d1 = self.down1_act(d1_conv)
|
||
|
|
||
|
d2_conv = self.down2_conv(d1)
|
||
|
d2 = self.down2_act(d2_conv)
|
||
|
|
||
|
d3_conv = self.down3_conv(d2)
|
||
|
d3 = self.down3_act(d3_conv)
|
||
|
|
||
|
d4_conv = self.down4_conv(d3)
|
||
|
d4 = self.down4_act(d4_conv)
|
||
|
|
||
|
d5_conv = self.down5_conv(d4)
|
||
|
d5 = self.down5_act(d5_conv)
|
||
|
|
||
|
d6_conv = self.down6_conv(d5)
|
||
|
d6 = self.down6_act(d6_conv)
|
||
|
|
||
|
u1 = self.up1(d6)
|
||
|
u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
|
||
|
u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
|
||
|
u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
|
||
|
u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
|
||
|
u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
|
||
|
u7 = self.up7(u6)
|
||
|
return u7 * x
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
net = UNet(14)
|
||
|
print(net(torch.rand(1, 14, 20, 48)).shape)
|