diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index 47a64b2d..c239bc60 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -374,7 +374,7 @@ class ExpansionBlock(nn.Module): filters_out = filters_in // 2 self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True) self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True) - self.conjoin = block(filters_in, filters_out, kernel_size=3, bias=False, activation=True, norm=False) + self.conjoin = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=False) self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True) # input is the feature signal with shape (b, f, w, h)