Fix ref branch using fixed filters

This commit is contained in:
James Betker 2020-09-11 08:58:35 -06:00
parent 8c469b8286
commit 1086f0476b

View File

@ -359,8 +359,8 @@ class SwitchedSpsr(nn.Module):
class RefJoiner(nn.Module): class RefJoiner(nn.Module):
def __init__(self, nf): def __init__(self, nf):
super(RefJoiner, self).__init__() super(RefJoiner, self).__init__()
self.lin1 = nn.Linear(512, 256) self.lin1 = nn.Linear(nf * 8, nf * 4)
self.lin2 = nn.Linear(256, nf) self.lin2 = nn.Linear(nf * 4, nf)
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
def forward(self, x, ref): def forward(self, x, ref):