diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 6fcdfab0..58b02c86 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -133,7 +133,7 @@ class ReferenceImageBranch(nn.Module): class ReferencingConvMultiplexer(nn.Module): def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True): super(ReferencingConvMultiplexer, self).__init__() - self.filter_conv = ConvGnSilu(input_channels, multiplexer_channels, bias=True) + self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True) self.ref_proc = nn.Linear(512, 512) self.ref_red = nn.Linear(512, base_filters * 2) self.feature_norm = torch.nn.InstanceNorm2d(base_filters)