Fix multifaceted chain gen

This commit is contained in:
James Betker 2020-10-22 13:27:06 -06:00
parent f9dc472f63
commit 40dc2938e8

View File

@ -226,8 +226,7 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
# Integrate recurrence inputs. # Integrate recurrence inputs.
if teco_recurrent is not None: if teco_recurrent is not None:
teco_rec = torch.nn.functional.interpolate(teco_recurrent, scale_factor=2, mode='nearest') teco_rec = self.teco_recurrent_process(teco_recurrent)
teco_rec = self.teco_recurrent_process(teco_rec)
fea, std = self.teco_recurrent_join(fea, teco_rec) fea, std = self.teco_recurrent_join(fea, teco_rec)
self.teco_ref_std = std.item() self.teco_ref_std = std.item()
elif prog_recurrent is not None: elif prog_recurrent is not None: