diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py
index 08646884..8aad83e5 100644
--- a/codes/models/gpt_voice/lucidrains_gpt.py
+++ b/codes/models/gpt_voice/lucidrains_gpt.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 # helpers
+from models.arch_util import checkpoint
 from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
 from utils.util import sequential_checkpoint
 
@@ -208,8 +209,8 @@ class Transformer(nn.Module):
     def forward(self, x, return_intermediates=False):
         intermediates = []
         for attn, ff in self.layers.layers:
-            x_ff = attn(x)
-            x = ff(x_ff)
+            x_ff = x + checkpoint(attn, x)
+            x = x + ff(x_ff)
             if return_intermediates:
                 intermediates.append((x_ff, x))
         if return_intermediates:
@@ -228,9 +229,9 @@ class Transformer(nn.Module):
         assert(len(prev_intermediates) == self.depth)
         new_intermediates = []
         for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
-            x = attn(x, only_last_two_elements=True)
+            x = x + attn(x, only_last_two_elements=True)
             # Note that (x) is now only the last two elements in the set. Conjoin it with the int_ff latent to compute the norm.
             x_ff = torch.cat([int_ff[:,:-1], x], dim=1)
-            x = ff(x_ff, only_last_two_elements=True)
+            x = x + ff(x_ff, only_last_two_elements=True)
             new_intermediates.append((x_ff, x))
         return x, new_intermediates
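
For context, this change moves the residual connections out of the attention/feed-forward wrappers and into the `Transformer` loop itself, and wraps the attention call in activation checkpointing. Below is a minimal sketch of the resulting pattern. The `checkpoint` helper, the `PreNorm` wrapper, and the `nn.Linear` stand-ins are illustrative assumptions, not the repo's code: the real `models.arch_util.checkpoint` is not part of this diff, and the hypothetical version here simply defers to `torch.utils.checkpoint.checkpoint`.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


def checkpoint(fn, *args):
    # Hypothetical stand-in for models.arch_util.checkpoint (not shown in this
    # diff): recompute fn's activations during backward instead of storing
    # them, falling back to a plain call when no gradient is required.
    if torch.is_grad_enabled() and any(
            isinstance(a, torch.Tensor) and a.requires_grad for a in args):
        return torch.utils.checkpoint.checkpoint(fn, *args)
    return fn(*args)


class PreNorm(nn.Module):
    # Illustrative pre-norm wrapper. The residual add is deliberately NOT done
    # here: after this change the Transformer loop owns the residual, which is
    # what makes `x + checkpoint(attn, x)` correct.
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


dim = 16
attn = PreNorm(dim, nn.Linear(dim, dim))  # stand-in for an attention block
ff = PreNorm(dim, nn.Sequential(
    nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)))

x = torch.randn(2, 10, dim, requires_grad=True)
x_ff = x + checkpoint(attn, x)  # residual around checkpointed attention
x = x + ff(x_ff)                # residual around the feed-forward
x.sum().backward()              # gradients flow through the recomputed attn
```

Checkpointing only the attention call (and not the feed-forward) trades recomputation cost for the largest activation savings, since attention activations typically grow quadratically with sequence length. The incremental path in the second hunk gains the explicit residuals but no checkpointing, presumably because it runs at inference time, where activations need not be stored for backward.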