Fix for referencingmultiplexer
This commit is contained in:
parent
19487d9bbd
commit
a1800f45ef
|
@ -133,7 +133,7 @@ class ReferenceImageBranch(nn.Module):
|
|||
class ReferencingConvMultiplexer(nn.Module):
|
||||
def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
|
||||
super(ReferencingConvMultiplexer, self).__init__()
|
||||
self.filter_conv = ConvGnSilu(input_channels, multiplexer_channels, bias=True)
|
||||
self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
|
||||
self.ref_proc = nn.Linear(512, 512)
|
||||
self.ref_red = nn.Linear(512, base_filters * 2)
|
||||
self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
|
||||
|
|
Loading…
Reference in New Issue
Block a user