Fix feedforward
This commit is contained in:
parent
b476516340
commit
95ca88efce
|
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
|
from models.arch_util import checkpoint
|
||||||
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
|
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
|
||||||
from utils.util import sequential_checkpoint
|
from utils.util import sequential_checkpoint
|
||||||
|
|
||||||
|
@ -208,8 +209,8 @@ class Transformer(nn.Module):
|
||||||
def forward(self, x, return_intermediates=False):
|
def forward(self, x, return_intermediates=False):
|
||||||
intermediates = []
|
intermediates = []
|
||||||
for attn, ff in self.layers.layers:
|
for attn, ff in self.layers.layers:
|
||||||
x_ff = attn(x)
|
x_ff = x + checkpoint(attn, x)
|
||||||
x = ff(x_ff)
|
x = x + ff(x_ff)
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
intermediates.append((x_ff, x))
|
intermediates.append((x_ff, x))
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
|
@ -228,9 +229,9 @@ class Transformer(nn.Module):
|
||||||
assert(len(prev_intermediates) == self.depth)
|
assert(len(prev_intermediates) == self.depth)
|
||||||
new_intermediates = []
|
new_intermediates = []
|
||||||
for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
|
for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
|
||||||
x = attn(x, only_last_two_elements=True)
|
x = x + attn(x, only_last_two_elements=True)
|
||||||
# Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm.
|
# Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm.
|
||||||
x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
|
x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
|
||||||
x = ff(x_ff, only_last_two_elements=True)
|
x = x + ff(x_ff, only_last_two_elements=True)
|
||||||
new_intermediates.append((x_ff, x))
|
new_intermediates.append((x_ff, x))
|
||||||
return x, new_intermediates
|
return x, new_intermediates
|
||||||
|
|
Loading…
Reference in New Issue
Block a user