Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-10-12 17:43:51 -06:00
commit 4d52374e60

View File

@ -204,7 +204,9 @@ class SSGr1(SwitchModelBase):
x = self.model_fea_conv(x) x = self.model_fea_conv(x)
if self.recurrent: if self.recurrent:
rec = self.recurrent_process(recurrent) rec = self.recurrent_process(recurrent)
x = self.recurrent_join(x, rec) x, recurrent_join_std = self.recurrent_join(x, rec)
else:
recurrent_join_std = 0
x1, a1 = checkpoint(self.sw1, x, ref_embedding) x1, a1 = checkpoint(self.sw1, x, ref_embedding)
x_grad = self.grad_conv(x_grad) x_grad = self.grad_conv(x_grad)
@ -319,7 +321,7 @@ class SSGDeep(SwitchModelBase):
x = self.model_fea_conv(x) x = self.model_fea_conv(x)
if self.recurrent: if self.recurrent:
rec = self.recurrent_process(recurrent) rec = self.recurrent_process(recurrent)
x = self.recurrent_join(x, rec) x, recurrent_std = self.recurrent_join(x, rec)
x1, a1 = checkpoint(self.sw1, x, ref_embedding) x1, a1 = checkpoint(self.sw1, x, ref_embedding)
x2, a2 = checkpoint(self.sw2, x1, ref_embedding) x2, a2 = checkpoint(self.sw2, x1, ref_embedding)