checkpointing for bitnet impl

mrq 2024-05-12 07:52:54 -05:00
parent 14709ac67f
commit 9910c75d5a


@@ -56,6 +56,20 @@ except Exception as e:
try:
    from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

    # override for wrapping checkpointing
    def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        skip = x
        for attn, ffn in zip(self.layers, self.ffn_layers):
            if x.requires_grad and self.activation_checkpointing:
                x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
            else:
                x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
            x = x + skip
            x = ffn(x) + x
        return x

    BitNetTransformerBlock.forward = BitNetTransformerBlock_forward
    # override because bitnet's BitNetTransformer includes embedding input / classifier output layers inside of it, which isn't favorable
    class BitNetTransformer(nn.Module):
        def __init__(
@@ -65,11 +79,13 @@ try:
            num_tokens: int,
            heads=8,
            ff_mult=4,
            activation_checkpointing=True
        ):
            super().__init__()
            self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
            self.norm = BitNetRMSNorm(dim)
            self.transformer.activation_checkpointing = activation_checkpointing

        def forward(self, x):
            x = self.transformer(x)
@@ -654,6 +670,7 @@ class Base(nn.Module):
                depth=n_layers,
                heads=n_heads,
                ff_mult=4,
                activation_checkpointing=self.activation_checkpointing,
            )
        else:
            raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
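
For reference, a minimal, self-contained sketch of the checkpointing pattern the forward override above relies on: torch.utils.checkpoint.checkpoint with use_reentrant=False, gated on requires_grad and a per-module flag. The TinyBlock and CheckpointedStack names are illustrative stand-ins, not part of this commit or of bitnet.

    # Standalone sketch (assumes a recent PyTorch, >= 2.0); TinyBlock / CheckpointedStack are hypothetical.
    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class TinyBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            return torch.relu(self.proj(x))

    class CheckpointedStack(nn.Module):
        def __init__(self, dim, depth, activation_checkpointing=True):
            super().__init__()
            self.layers = nn.ModuleList(TinyBlock(dim) for _ in range(depth))
            self.activation_checkpointing = activation_checkpointing

        def forward(self, x):
            for layer in self.layers:
                if x.requires_grad and self.activation_checkpointing:
                    # drop this layer's activations and recompute them during backward
                    x = checkpoint(layer, x, use_reentrant=False)
                else:
                    # inference (or checkpointing disabled): plain forward call
                    x = layer(x)
            return x

    x = torch.randn(2, 16, requires_grad=True)
    model = CheckpointedStack(dim=16, depth=4)
    model(x).sum().backward()

The requires_grad guard matters because checkpoint warns (and saves nothing) when no input requires gradients, i.e. during inference; the non-reentrant variant is also what allows keyword arguments such as is_causal=True to be forwarded to the wrapped attention module in the diff above.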