From df45a9dec24576137b5daae3713baf46ba25cc5f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 30 Oct 2021 16:59:18 -0600
Subject: [PATCH] Fix inference mode for lucidrains_gpt

---
 codes/models/gpt_voice/lucidrains_gpt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py
index 1b314af2..b218f7fc 100644
--- a/codes/models/gpt_voice/lucidrains_gpt.py
+++ b/codes/models/gpt_voice/lucidrains_gpt.py
@@ -228,9 +228,9 @@ class Transformer(nn.Module):
             assert(len(prev_intermediates) == self.depth)
             new_intermediates = []
             for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
-                x = x + attn(x, only_last_two_elements=True)
+                x_ff = attn(x, only_last_two_elements=True)
                 # Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm.
-                x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
-                x = x + ff(x_ff, only_last_two_elements=True)
+                x_ff = x + torch.cat([int_ff[:,:-1], x_ff], dim=1)
+                x = x_ff + ff(x_ff, only_last_two_elements=True)
                 new_intermediates.append((x_ff, x))
             return x, new_intermediates
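
For readers who want to see the corrected ordering in isolation, below is a minimal, self-contained sketch; it is not the repository's code. ToyBlock and incremental_layer_step are hypothetical stand-ins, and the sketch assumes that only_last_two_elements=True makes a sub-block return just the last two positions' outputs. The point it illustrates is the one the patch makes: the freshly computed attention tail is stitched onto the cached pre-feed-forward latent (int_ff) before the residual is added, and the feed-forward residual is taken against x_ff rather than the pre-attention x.

# Minimal sketch of the incremental-inference residual ordering (assumed
# semantics, not the repository's API).
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for the attention / feed-forward sub-blocks named in the diff."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, only_last_two_elements=False):
        if only_last_two_elements:
            x = x[:, -2:]  # assumption: recompute only the trailing two positions
        return self.proj(x)


def incremental_layer_step(x, attn, ff, int_ff):
    # x:      (b, s, d)   sequence after appending the newest token
    # int_ff: (b, s-1, d) cached pre-feed-forward latent from the previous step
    attn_tail = attn(x, only_last_two_elements=True)           # (b, 2, d)
    # Stitch the fresh tail onto the cache, then add the residual against the
    # full-length input -- the ordering the patch establishes.
    x_ff = x + torch.cat([int_ff[:, :-1], attn_tail], dim=1)   # (b, s, d)
    # Feed-forward residual is taken against x_ff; this toy block returns only
    # the tail, so earlier positions keep their x_ff value.
    ff_tail = ff(x_ff, only_last_two_elements=True)            # (b, 2, d)
    x_out = torch.cat([x_ff[:, :-2], x_ff[:, -2:] + ff_tail], dim=1)
    return x_out, (x_ff, x_out)  # output plus the new (int_ff, int_out) cache entry


if __name__ == "__main__":
    b, s, d = 1, 6, 16
    attn, ff = ToyBlock(d), ToyBlock(d)
    x = torch.randn(b, s, d)
    int_ff = torch.randn(b, s - 1, d)
    x_out, new_intermediate = incremental_layer_step(x, attn, ff, int_ff)
    print(x_out.shape, new_intermediate[0].shape)  # both torch.Size([1, 6, 16])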