diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py
index 9fe0b69..0344464 100644
--- a/bitsandbytes/nn/triton_based_modules.py
+++ b/bitsandbytes/nn/triton_based_modules.py
@@ -133,7 +133,7 @@ class SwitchBackGlobalLinear(nn.Linear):
 
 
 
-class LinearFunction(torch.autograd.Function):
+class StandardLinearFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, weight, bias=None):
         X = input.view(-1, input.size(-1))
@@ -161,87 +161,8 @@ class LinearFunction(torch.autograd.Function):
 
         return grad_input, grad_weight, grad_bias
 
 
-class MyLinear(nn.Linear):
+class StandardLinear(nn.Linear):
 
     def forward(self, x):
-        return LinearFunction.apply(x, self.weight, self.bias)
+        return StandardLinearFunction.apply(x, self.weight, self.bias)
-
-
-
-class _switchback_mlp(torch.autograd.Function):
-
-
-    @staticmethod
-    def forward(ctx, X_3D, W1, B1, W2, B2):
-
-        X1 = X_3D.view(-1, X_3D.size(-1))
-
-        X1_int8, state_X1 = quantize_rowwise_nogroup(X1)
-        W1_int8, state_W1 = quantize_global(W1)
-
-        X2_pre = int8_matmul_mixed_dequanitze_bias(
-            X1_int8, W1_int8.t(), state_X1, state_W1, B1
-        )
-
-        # X2_v1 = torch.nn.functional.gelu(X2)
-        # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1)
-        X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre)
-
-        W2_int8, state_W2 = quantize_global(W2)
-
-        out = int8_matmul_mixed_dequanitze_bias(
-            X2_int8, W2_int8.t(), state_X2, state_W2, B2
-        )
-
-        ctx.save_for_backward = X1, W1, X2, X2_pre, W2
-
-        return out.view(*X_3D.size()[:-1], -1)
-
-    @staticmethod
-    def backward(ctx, G_3D):
-
-        G2 = G_3D.reshape(-1, G_3D.size(-1))
-
-        grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None
-
-        X1, W1, X2, X2_pre, W2 = ctx.save_for_backward
-
-        G2_int8, state_G2 = quantize_rowwise_nogroup(G2)
-        W2_int8, state_W2 = quantize_global_transpose(W2)
-
-        G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view(
-            *G_3D.size()[:-1], -1
-        )
-
-        grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype))
-        grad_B2 = G2.sum(dim=0)
-
-        G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre)
-
-        if ctx.needs_input_grad[0]:
-
-            W1_int8, state_W1 = quantize_global_transpose(W1)
-            grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view(
-                *G_3D.size()[:-1], -1
-            )
-        if ctx.needs_input_grad[1]:
-            grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype))
-        if ctx.needs_input_grad[2]:
-            grad_B1 = G1.sum(dim=0)
-
-        return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2
-
-
-class SwitchBackGlobalMLP(nn.Module):
-
-
-    def __init__(self, dim_in, dim_hidden):
-        super().__init__()
-        self.linear1 = nn.Linear(dim_in, dim_hidden)
-        self.linear2 = nn.Linear(dim_hidden, dim_in)
-
-
-    def forward(self, x):
-        return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias)
-
\ No newline at end of file
diff --git a/tests/triton_tests/attn_decomp.py b/tests/triton_tests/attn_decomp.py
index fa86995..b70bceb 100644
--- a/tests/triton_tests/attn_decomp.py
+++ b/tests/triton_tests/attn_decomp.py
@@ -1,7 +1,7 @@
 import torch
 import json
 
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time
 
 # class AttentionOld(torch.nn.Module):
@@ -116,7 +116,7 @@ if __name__ == '__main__':
     va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
 
     standard = Attention(dim).cuda()
-    my_standard = Attention(dim, linear_module=MyLinear).cuda()
+    my_standard = Attention(dim, linear_module=StandardLinear).cuda()
     sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
     standard_compiled = torch.compile(standard)
     ln_model = torch.nn.Sequential(
@@ -360,4 +360,4 @@ if __name__ == '__main__':
 
     # import pdb; pdb.set_trace()
 
-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
diff --git a/tests/triton_tests/full_matrix_decomp.py b/tests/triton_tests/full_matrix_decomp.py
index de37b95..e2932d4 100644
--- a/tests/triton_tests/full_matrix_decomp.py
+++ b/tests/triton_tests/full_matrix_decomp.py
@@ -4,7 +4,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
 
 from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
 from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
@@ -350,4 +350,4 @@ if __name__ == '__main__':
 
 
     with open("tests/triton_tests/info.jsonl", "a") as file:
-        file.write(info_json + "\n")
\ No newline at end of file
+        file.write(info_json + "\n")
diff --git a/tests/triton_tests/mlp.py b/tests/triton_tests/mlp.py
index 1ec85b8..8aef105 100644
--- a/tests/triton_tests/mlp.py
+++ b/tests/triton_tests/mlp.py
@@ -3,7 +3,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
 
 def construct_model(dim, layers, module):
     modules = []
@@ -41,7 +41,7 @@ if __name__ == '__main__':
 
     # construct models
     standard = construct_model(dim, layers, nn.Linear).half()
-    my_standard = construct_model(dim, layers, MyLinear).half()
+    my_standard = construct_model(dim, layers, StandardLinear).half()
     switchback = construct_model(dim, layers, SwitchBackLinear).half()
     switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
     #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)
@@ -61,4 +61,4 @@ if __name__ == '__main__':
 
 
 
-    
\ No newline at end of file
+    
diff --git a/tests/triton_tests/mlp_decomp_autocast.py b/tests/triton_tests/mlp_decomp_autocast.py
index 3a1fc9e..54bd5f5 100644
--- a/tests/triton_tests/mlp_decomp_autocast.py
+++ b/tests/triton_tests/mlp_decomp_autocast.py
@@ -1,7 +1,7 @@
 import torch
 import json
 
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time
 
 if __name__ == '__main__':
@@ -26,9 +26,9 @@ if __name__ == '__main__':
     ).cuda()
 
     my_standard = torch.nn.Sequential(
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()
 
     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
@@ -163,4 +163,4 @@ if __name__ == '__main__':
 
     # import pdb; pdb.set_trace()
 
-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
diff --git a/tests/triton_tests/mlp_decomp_autocast_ln.py b/tests/triton_tests/mlp_decomp_autocast_ln.py
index 2596278..0a50cab 100644
--- a/tests/triton_tests/mlp_decomp_autocast_ln.py
+++ b/tests/triton_tests/mlp_decomp_autocast_ln.py
@@ -1,7 +1,7 @@
 import torch
 import json
 
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time
 
 if __name__ == '__main__':
@@ -24,9 +24,9 @@ if __name__ == '__main__':
 
     my_standard = torch.nn.Sequential(
         torch.nn.LayerNorm(dim),
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()
 
     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
@@ -162,4 +162,4 @@ if __name__ == '__main__':
 
     # import pdb; pdb.set_trace()
 
-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
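
For reference, a minimal usage sketch of the renamed StandardLinear module, assuming a CUDA device and that StandardLinear is importable from bitsandbytes.nn.triton_based_modules as in the tests above; the shapes, dtype, seed, and tolerances are illustrative, not part of the patch.

import torch
from bitsandbytes.nn.triton_based_modules import StandardLinear

# StandardLinear subclasses nn.Linear and only overrides forward() to route
# through StandardLinearFunction, so it should act as a drop-in replacement
# for torch.nn.Linear with the same parameters (assumption based on the diff).
torch.manual_seed(0)
dim = 64

reference = torch.nn.Linear(dim, 4 * dim).cuda().half()
standard = StandardLinear(dim, 4 * dim).cuda().half()
standard.load_state_dict(reference.state_dict())

x = torch.randn(8, dim, device="cuda", dtype=torch.float16)

# Outputs are expected to agree up to fp16 rounding; tolerance is illustrative.
assert torch.allclose(reference(x), standard(x), atol=1e-2, rtol=1e-2)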