James Betker 2020-07-10 22:57:34 -06:00
parent b3a2c21250
commit 020b3361fa
2 changed files with 16 additions and 14 deletions


@@ -368,12 +368,14 @@ class ConvGnSilu(nn.Module):
 # Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
 # along with the feature representation.
 class ExpansionBlock(nn.Module):
-    def __init__(self, filters, block=ConvGnSilu):
+    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
         super(ExpansionBlock, self).__init__()
-        self.decimate = block(filters, filters // 2, kernel_size=1, bias=False, activation=False, norm=True)
-        self.process_passthrough = block(filters // 2, filters // 2, kernel_size=3, bias=True, activation=False, norm=True)
-        self.conjoin = block(filters, filters // 2, kernel_size=3, bias=False, activation=True, norm=False)
-        self.process = block(filters // 2, filters // 2, kernel_size=3, bias=False, activation=True, norm=True)
+        if filters_out is None:
+            filters_out = filters_in // 2
+        self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
+        self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
+        self.conjoin = block(filters_in, filters_out, kernel_size=3, bias=False, activation=True, norm=False)
+        self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
 
     # input is the feature signal with shape (b, f, w, h)
    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
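
For context, a minimal sketch of a forward pass consistent with these module definitions and the shape comments. The real forward() lies outside this hunk, so the nearest-neighbor upsample and the concatenation order are assumptions; shapes shown assume the default filters_out = filters_in // 2:

import torch
import torch.nn.functional as F

def forward(self, input, passthrough):
    # Sketch of an ExpansionBlock method body, not the committed code.
    x = F.interpolate(input, scale_factor=2, mode="nearest")  # (b, f, w*2, h*2)
    x = self.decimate(x)                                      # (b, f/2, w*2, h*2)
    p = self.process_passthrough(passthrough)                 # (b, f/2, w*2, h*2)
    x = self.conjoin(torch.cat([x, p], dim=1))                # cat -> (b, f, ...), conjoin -> (b, f/2, ...)
    return self.process(x)                                    # (b, f/2, w*2, h*2)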


@@ -192,17 +192,17 @@ class Discriminator_UNet(nn.Module):
         self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
         self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
-        self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu)
-        self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False)
-        self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+        self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
-        self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu)
-        self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False)
-        self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+        self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
-        self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu)
-        self.proc3 = ConvGnLelu(nf, nf, bias=False)
-        self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False)
+        self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
 
     def forward(self, x, flatten=True):
         x = x[0]
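
The net effect of the new explicit filters_out argument is that the discriminator's decoder runs at double the channel width of the old version at every level. A small hypothetical walkthrough of the channel bookkeeping; nf = 64 is an assumed value (nf comes from constructor arguments outside this hunk):

nf = 64  # assumption; the real value is a constructor argument

# Old path halved at every step: 512 -> 256 -> 128 -> 64 across up1..up3.
# The new path keeps the first decoder level at full width:
widths = [
    (nf * 8, nf * 8),  # up1: 512 -> 512 (old code: 512 -> 256)
    (nf * 8, nf * 4),  # up2: 512 -> 256
    (nf * 4, nf * 2),  # up3: 256 -> 128
]
for i, (f_in, f_out) in enumerate(widths, start=1):
    # procN preserves width; collapseN reduces to a one-channel prediction map
    print(f"up{i}: {f_in} -> {f_out}; proc{i}: {f_out} -> {f_out}; collapse{i}: {f_out} -> 1")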