This commit is contained in:
James Betker 2021-10-29 17:29:49 -06:00
parent 95ca88efce
commit 92fe8b4dd9

View File

@ -6,9 +6,8 @@ import torch.nn.functional as F
from einops import rearrange
# helpers
from models.arch_util import checkpoint
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
from utils.util import sequential_checkpoint
from utils.util import checkpoint
def exists(val):
@ -210,7 +209,7 @@ class Transformer(nn.Module):
intermediates = []
for attn, ff in self.layers.layers:
x_ff = x + checkpoint(attn, x)
x = x + ff(x_ff)
x = x_ff + ff(x_ff)
if return_intermediates:
intermediates.append((x_ff, x))
if return_intermediates: