checkpointing for bitnet impl
parent 14709ac67f
commit 9910c75d5a
@@ -56,6 +56,20 @@ except Exception as e:
 try:
     from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

+    # override for wrapping checkpointing
+    def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+        skip = x
+        for attn, ffn in zip(self.layers, self.ffn_layers):
+            if x.requires_grad and self.activation_checkpointing:
+                x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
+            else:
+                x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
+            x = x + skip
+            x = ffn(x) + x
+        return x
+
+    BitNetTransformerBlock.forward = BitNetTransformerBlock_forward
+
     # override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable
     class BitNetTransformer(nn.Module):
         def __init__(
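Note: the override above leans on torch.utils.checkpoint. A minimal, self-contained sketch of that pattern follows, with a stock nn.MultiheadAttention standing in for bitnet's attention layer; the dims are made up and this is illustrative, not code from this repo.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

attn = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 16, 64, requires_grad=True)

# plain call: all intermediate activations stay alive until backward
y, _ = attn(x, x, x)

# checkpointed call: activations are dropped after forward and recomputed
# during backward, trading extra compute for lower peak memory;
# use_reentrant=False selects the non-reentrant implementation, which also
# lets keyword arguments pass through cleanly
y_ckpt, _ = checkpoint(attn, x, x, x, use_reentrant=False)

y_ckpt.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16, 64]); grads flow through the recompute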
@@ -65,11 +79,13 @@ try:
             num_tokens: int,
             heads=8,
             ff_mult=4,
+            activation_checkpointing = True
         ):
             super().__init__()

             self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
             self.norm = BitNetRMSNorm(dim)
+            self.transformer.activation_checkpointing = activation_checkpointing

         def forward(self, x):
             x = self.transformer(x)
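A hedged usage sketch of the wrapper as patched above: dim and depth sit just outside the hunk but are implied by the call site in Base below, num_tokens is kept for signature compatibility, the input is assumed to be already-embedded features (the point of dropping bitnet's built-in embedding/classifier), and forward is assumed to return the normed hidden states.

import torch

model = BitNetTransformer(
    dim=1024,           # assumed names/values; the hunk cuts these off
    depth=12,
    num_tokens=256,
    heads=8,
    ff_mult=4,
    activation_checkpointing=True,
)

x = torch.randn(1, 128, 1024, requires_grad=True)  # already-embedded features
y = model(x)
y.sum().backward()  # the checkpointed attention recomputes its activations here

# the x.requires_grad guard in the patched forward means inference never
# pays the recompute cost: under no_grad the plain attention path is taken
with torch.no_grad():
    _ = model(torch.randn(1, 128, 1024))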
@@ -654,6 +670,7 @@ class Base(nn.Module):
                 depth=n_layers,
                 heads=n_heads,
                 ff_mult=4,
+                activation_checkpointing=self.activation_checkpointing,
             )
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
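And a rough, hedged way to sanity-check the trade-off on a CUDA box, reusing the wrapper from the hunks above; the dims are arbitrary and the expectation is simply that peak memory drops (and step time rises) with the flag on.

import torch

def peak_backward_mib(model, x):
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20

for ckpt in (False, True):
    model = BitNetTransformer(
        dim=1024, depth=12, num_tokens=256, heads=8, ff_mult=4,
        activation_checkpointing=ckpt,
    ).cuda()
    x = torch.randn(1, 512, 1024, device="cuda", requires_grad=True)
    print(f"activation_checkpointing={ckpt}: peak {peak_backward_mib(model, x):.1f} MiB")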