diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index 86e10e0f..47a64b2d 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -368,12 +368,14 @@ class ConvGnSilu(nn.Module): # Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed # along with the feature representation. class ExpansionBlock(nn.Module): - def __init__(self, filters, block=ConvGnSilu): + def __init__(self, filters_in, filters_out=None, block=ConvGnSilu): super(ExpansionBlock, self).__init__() - self.decimate = block(filters, filters // 2, kernel_size=1, bias=False, activation=False, norm=True) - self.process_passthrough = block(filters // 2, filters // 2, kernel_size=3, bias=True, activation=False, norm=True) - self.conjoin = block(filters, filters // 2, kernel_size=3, bias=False, activation=True, norm=False) - self.process = block(filters // 2, filters // 2, kernel_size=3, bias=False, activation=True, norm=True) + if filters_out is None: + filters_out = filters_in // 2 + self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True) + self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True) + self.conjoin = block(filters_in, filters_out, kernel_size=3, bias=False, activation=True, norm=False) + self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True) # input is the feature signal with shape (b, f, w, h) # passthrough is the structure signal with shape (b, f/2, w*2, h*2) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 0a109247..559c2f16 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -192,17 +192,17 @@ class Discriminator_UNet(nn.Module): self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False) self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) - self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu) - self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False) - self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False) + self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) + self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False) + self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False) - self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu) - self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False) - self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) + self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu) + self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False) + self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False) - self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu) - self.proc3 = ConvGnLelu(nf, nf, bias=False) - self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False) + self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu) + self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) + self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) def forward(self, x, flatten=True): x = x[0]