diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py index 8aad83e5..1b314af2 100644 --- a/codes/models/gpt_voice/lucidrains_gpt.py +++ b/codes/models/gpt_voice/lucidrains_gpt.py @@ -6,9 +6,8 @@ import torch.nn.functional as F from einops import rearrange # helpers -from models.arch_util import checkpoint from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence -from utils.util import sequential_checkpoint +from utils.util import checkpoint def exists(val): @@ -210,7 +209,7 @@ class Transformer(nn.Module): intermediates = [] for attn, ff in self.layers.layers: x_ff = x + checkpoint(attn, x) - x = x + ff(x_ff) + x = x_ff + ff(x_ff) if return_intermediates: intermediates.append((x_ff, x)) if return_intermediates: