This commit is contained in:
James Betker 2021-10-29 17:29:49 -06:00
parent 95ca88efce
commit 92fe8b4dd9

View File

@ -6,9 +6,8 @@ import torch.nn.functional as F
from einops import rearrange from einops import rearrange
# helpers # helpers
from models.arch_util import checkpoint
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
from utils.util import sequential_checkpoint from utils.util import checkpoint
def exists(val): def exists(val):
@ -210,7 +209,7 @@ class Transformer(nn.Module):
intermediates = [] intermediates = []
for attn, ff in self.layers.layers: for attn, ff in self.layers.layers:
x_ff = x + checkpoint(attn, x) x_ff = x + checkpoint(attn, x)
x = x + ff(x_ff) x = x_ff + ff(x_ff)
if return_intermediates: if return_intermediates:
intermediates.append((x_ff, x)) intermediates.append((x_ff, x))
if return_intermediates: if return_intermediates: