Fix for referencingmultiplexer

This commit is contained in:
James Betker 2020-08-25 15:43:12 -06:00
parent 19487d9bbd
commit a1800f45ef

View File

@ -133,7 +133,7 @@ class ReferenceImageBranch(nn.Module):
class ReferencingConvMultiplexer(nn.Module): class ReferencingConvMultiplexer(nn.Module):
def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True): def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
super(ReferencingConvMultiplexer, self).__init__() super(ReferencingConvMultiplexer, self).__init__()
self.filter_conv = ConvGnSilu(input_channels, multiplexer_channels, bias=True) self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
self.ref_proc = nn.Linear(512, 512) self.ref_proc = nn.Linear(512, 512)
self.ref_red = nn.Linear(512, base_filters * 2) self.ref_red = nn.Linear(512, base_filters * 2)
self.feature_norm = torch.nn.InstanceNorm2d(base_filters) self.feature_norm = torch.nn.InstanceNorm2d(base_filters)