diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index e22fcc64..c98523f4 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -359,8 +359,8 @@ class SwitchedSpsr(nn.Module): class RefJoiner(nn.Module): def __init__(self, nf): super(RefJoiner, self).__init__() - self.lin1 = nn.Linear(512, 256) - self.lin2 = nn.Linear(256, nf) + self.lin1 = nn.Linear(nf * 8, nf * 4) + self.lin2 = nn.Linear(nf * 4, nf) self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) def forward(self, x, ref):