Fix multifaceted chain gen

This commit is contained in:
James Betker 2020-10-22 13:27:06 -06:00
parent f9dc472f63
commit 40dc2938e8

View File

@ -226,8 +226,7 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
# Integrate recurrence inputs.
if teco_recurrent is not None:
teco_rec = torch.nn.functional.interpolate(teco_recurrent, scale_factor=2, mode='nearest')
teco_rec = self.teco_recurrent_process(teco_rec)
teco_rec = self.teco_recurrent_process(teco_recurrent)
fea, std = self.teco_recurrent_join(fea, teco_rec)
self.teco_ref_std = std.item()
elif prog_recurrent is not None: