Fix recurrent std in arch

This commit is contained in:
James Betker 2020-10-12 17:42:32 -06:00
parent 05377973bf
commit ca523215c6

View File

@ -204,7 +204,9 @@ class SSGr1(SwitchModelBase):
x = self.model_fea_conv(x)
if self.recurrent:
rec = self.recurrent_process(recurrent)
x = self.recurrent_join(x, rec)
x, recurrent_join_std = self.recurrent_join(x, rec)
else:
recurrent_join_std = 0
x1, a1 = checkpoint(self.sw1, x, ref_embedding)
x_grad = self.grad_conv(x_grad)
@ -319,7 +321,7 @@ class SSGDeep(SwitchModelBase):
x = self.model_fea_conv(x)
if self.recurrent:
rec = self.recurrent_process(recurrent)
x = self.recurrent_join(x, rec)
x, recurrent_std = self.recurrent_join(x, rec)
x1, a1 = checkpoint(self.sw1, x, ref_embedding)
x2, a2 = checkpoint(self.sw2, x1, ref_embedding)