From 9910c75d5a8ee43d9bcab87797c5f129be6c1ab3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 12 May 2024 07:52:54 -0500
Subject: [PATCH] checkpointing for bitnet impl

---
 vall_e/models/base.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0d587d6..54c0070 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -56,6 +56,20 @@ except Exception as e:
 try:
 	from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
 
+	# override for wrapping checkpointing
+	def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+		skip = x
+		for attn, ffn in zip(self.layers, self.ffn_layers):
+			if x.requires_grad and self.activation_checkpointing:
+				x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
+			else:
+				x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
+			x = x + skip
+			x = ffn(x) + x
+		return x
+
+	BitNetTransformerBlock.forward = BitNetTransformerBlock_forward
+
 	# override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable
 	class BitNetTransformer(nn.Module):
 		def __init__(
@@ -65,11 +79,13 @@ try:
 			num_tokens: int,
 			heads=8,
 			ff_mult=4,
+			activation_checkpointing = True
 		):
 			super().__init__()
 
 			self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
 			self.norm = BitNetRMSNorm(dim)
+			self.transformer.activation_checkpointing = activation_checkpointing
 
 		def forward(self, x):
 			x = self.transformer(x)
@@ -654,6 +670,7 @@ class Base(nn.Module):
 				depth=n_layers,
 				heads=n_heads,
 				ff_mult=4,
+				activation_checkpointing=self.activation_checkpointing,
 			)
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
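
Note (not part of the patch): the overridden forward above relies on torch.utils.checkpoint to recompute each attention block's activations during backward instead of keeping them in memory. Below is a minimal, self-contained sketch of that pattern under the same call shape the patch assumes, i.e. attn(q, k, v, is_causal=...) returning an (output, aux) pair. ToyAttention is a hypothetical stand-in for bitnet's attention block, not its actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class ToyAttention(nn.Module):
	# hypothetical stand-in for bitnet's attention block
	def __init__(self, dim: int):
		super().__init__()
		self.q_proj = nn.Linear(dim, dim)
		self.k_proj = nn.Linear(dim, dim)
		self.v_proj = nn.Linear(dim, dim)
		self.out = nn.Linear(dim, dim)

	def forward(self, q, k, v, is_causal: bool = False):
		# single-head scaled dot-product attention, returned as an (output, aux)
		# pair so callers can unpack it the way the patch does: x, _ = attn(...)
		x = F.scaled_dot_product_attention(self.q_proj(q), self.k_proj(k), self.v_proj(v), is_causal=is_causal)
		return self.out(x), None

dim = 64
x = torch.randn(2, 16, dim, requires_grad=True)
attn = ToyAttention(dim)

# Checkpointed call: intermediate activations inside `attn` are not stored during
# the forward pass; they are recomputed when backward runs, trading compute for memory.
y, _ = checkpoint(attn, x, x, x, is_causal=True, use_reentrant=False)

# Plain call for comparison: this is what the patch falls back to when
# checkpointing is disabled or the input does not require grad.
y_ref, _ = attn(x, x, x, is_causal=True)

y.sum().backward()

The non-reentrant variant (use_reentrant=False) is what allows keyword arguments such as is_causal to be forwarded through checkpoint, which is presumably why the patch passes it explicitly.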