forked from mrq/DL-Art-School
Fix for referencingmultiplexer
This commit is contained in:
parent
19487d9bbd
commit
a1800f45ef
|
@ -133,7 +133,7 @@ class ReferenceImageBranch(nn.Module):
|
||||||
class ReferencingConvMultiplexer(nn.Module):
|
class ReferencingConvMultiplexer(nn.Module):
|
||||||
def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
|
def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
|
||||||
super(ReferencingConvMultiplexer, self).__init__()
|
super(ReferencingConvMultiplexer, self).__init__()
|
||||||
self.filter_conv = ConvGnSilu(input_channels, multiplexer_channels, bias=True)
|
self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
|
||||||
self.ref_proc = nn.Linear(512, 512)
|
self.ref_proc = nn.Linear(512, 512)
|
||||||
self.ref_red = nn.Linear(512, base_filters * 2)
|
self.ref_red = nn.Linear(512, base_filters * 2)
|
||||||
self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
|
self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user