2024-06-06 01:30:43 +00:00
|
|
|
# https://github.com/kyegomez/BitNet
|
|
|
|
from torch import Tensor, nn
|
2024-06-06 14:48:43 +00:00
|
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
|
2024-06-06 01:30:43 +00:00
|
|
|
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
|
|
|
|
|
|
|
# re-enable logging, which zetascale disables
|
|
|
|
import logging
|
|
|
|
# Lower the root logger's threshold to DEBUG so records are emitted again.
# `logging.root` is the same object `logging.getLogger()` returns.
logging.root.setLevel(logging.DEBUG)
|
|
|
|
|
|
|
|
# override of bitnet's Transformer.forward that can wrap each attention call in gradient checkpointing
|
|
|
|
def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
    """Run the bitnet transformer stack, optionally with gradient checkpointing.

    Replacement for ``BitNetTransformerBlock.forward``. For each
    (attention, feed-forward) pair in ``self.layers`` / ``self.ffn_layers``:
    causal self-attention ``attn(x, x, x, is_causal=True, ...)``, add the
    residual ``skip``, then ``ffn(x) + x``.

    When ``self.gradient_checkpointing`` is truthy and ``x`` requires grad,
    each attention call goes through ``torch.utils.checkpoint.checkpoint``
    (non-reentrant), so its activations are recomputed during backward
    instead of being stored. A block that never had the
    ``gradient_checkpointing`` attribute set runs without checkpointing
    rather than raising ``AttributeError``.

    Args:
        x: tensor used as query, key and value for every attention layer.
        *args, **kwargs: forwarded to every attention layer call.

    Returns:
        The tensor produced by the last layer.
    """
    # Default to False so the patched forward also works on blocks that were
    # not constructed through the BitNetTransformer wrapper.
    use_checkpointing = getattr(self, "gradient_checkpointing", False)

    # NOTE(review): `skip` is captured once, so every layer adds the block's
    # *input* (not that layer's input) as the attention residual. This
    # presumably mirrors upstream bitnet's Transformer.forward (this override
    # only adds checkpointing); changing it would alter the outputs of
    # already-trained weights, so confirm before touching.
    skip = x

    for attn, ffn in zip(self.layers, self.ffn_layers):
        # Re-checked per layer: x only starts requiring grad once it has
        # passed through a layer with trainable parameters.
        if x.requires_grad and use_checkpointing:
            x, _ = checkpoint(attn, x, x, x, *args, is_causal=True, use_reentrant=False, **kwargs)
        else:
            x, _ = attn(x, x, x, *args, is_causal=True, **kwargs)

        x = x + skip
        x = ffn(x) + x

    return x
|
|
|
|
|
|
|
|
# Install the checkpoint-aware forward on bitnet's Transformer class; this
# affects every instance, including ones created by other modules.
setattr(BitNetTransformerBlock, "forward", BitNetTransformerBlock_forward)
|
|
|
|
|
|
|
|
# override because bitnet's BitNetTransformer bundles an input embedding and an output classifier layer, which are not wanted here; this version keeps only the transformer stack and a final norm
|
|
|
|
class BitNetTransformer(nn.Module):
    """bitnet transformer layers followed by a final ``BitNetRMSNorm``.

    Unlike bitnet's own ``BitNetTransformer``, this module builds no input
    embedding and no output classifier. Callers presumably pass already-embedded
    inputs and post-process the normalized output themselves (confirm against
    call sites).

    Args:
        dim: width passed to both the transformer block and the norm.
        depth: forwarded to the transformer block as ``depth``.
        num_tokens: accepted but not used by this module.
        heads: forwarded to the transformer block as ``heads``.
        ff_mult: forwarded to the transformer block as ``ff_mult``.
        gradient_checkpointing: stored on the inner block as its
            ``gradient_checkpointing`` attribute, which
            ``BitNetTransformerBlock_forward`` consults.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_tokens: int,
        heads=8,
        ff_mult=4,
        gradient_checkpointing=True,
    ):
        super().__init__()

        block = BitNetTransformerBlock(dim=dim, depth=depth, heads=heads, ff_mult=ff_mult)
        block.gradient_checkpointing = gradient_checkpointing

        # Registration order (transformer, then norm) fixes parameters() and
        # state_dict ordering; keep it stable for saved optimizer states.
        self.transformer = block
        self.norm = BitNetRMSNorm(dim)

    def forward(self, x):
        """Apply the transformer stack, then the final norm."""
        return self.norm(self.transformer(x))
|
|
|
|
|
|
|
|
"""
|
|
|
|
from bitnet import BitNetTransformer
|
|
|
|
def NoEmbedding_BitNetTransformer_Forward(self, x):
|
|
|
|
x = self.transformer(x)
|
|
|
|
return self.to_logits[0](x)
|
|
|
|
|
|
|
|
BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
|
|
|
|
"""
|