Fix inference mode for lucidrains_gpt
This commit is contained in:
parent
466b9fbcaa
commit
df45a9dec2
|
@ -228,9 +228,9 @@ class Transformer(nn.Module):
|
|||
assert(len(prev_intermediates) == self.depth)
|
||||
new_intermediates = []
|
||||
for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
|
||||
x = x + attn(x, only_last_two_elements=True)
|
||||
x_ff = attn(x, only_last_two_elements=True)
|
||||
# Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm.
|
||||
x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
|
||||
x = x + ff(x_ff, only_last_two_elements=True)
|
||||
x_ff = x + torch.cat([int_ff[:,:-1], x_ff], dim=1)
|
||||
x = x_ff + ff(x_ff, only_last_two_elements=True)
|
||||
new_intermediates.append((x_ff, x))
|
||||
return x, new_intermediates
|
||||
|
|
Loading…
Reference in New Issue
Block a user