Added first triton test.

parent b373034e31
commit a13a522c4c
@@ -133,7 +133,7 @@ class SwitchBackGlobalLinear(nn.Linear):


-class LinearFunction(torch.autograd.Function):
+class StandardLinearFunction(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input, weight, bias=None):
         X = input.view(-1, input.size(-1))
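Note: the hunk above only shows the rename and the first line of the forward pass. As orientation for the rename, a minimal, self-contained sketch of a linear autograd.Function with this signature (illustrative only, not the repository's exact implementation) would be:

import torch

class StandardLinearFunctionSketch(torch.autograd.Function):
    # hypothetical reference mirroring forward(ctx, input, weight, bias=None) above
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))           # flatten leading dims
        ctx.save_for_backward(X, weight, bias)
        out = torch.matmul(X, weight.t())
        if bias is not None:
            out = out + bias
        return out.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output):
        X, weight, bias = ctx.saved_tensors
        G = grad_output.reshape(-1, grad_output.size(-1))
        grad_input = torch.matmul(G, weight).view(*grad_output.size()[:-1], -1)
        grad_weight = torch.matmul(G.t(), X)
        grad_bias = G.sum(dim=0) if bias is not None else None
        return grad_input, grad_weight, grad_bias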
@@ -161,87 +161,8 @@ class LinearFunction(torch.autograd.Function):

         return grad_input, grad_weight, grad_bias

-class MyLinear(nn.Linear):
+class StandardLinear(nn.Linear):

     def forward(self, x):
-        return LinearFunction.apply(x, self.weight, self.bias)
+        return StandardLinearFunction.apply(x, self.weight, self.bias)
-
-
-class _switchback_mlp(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, X_3D, W1, B1, W2, B2):
-
-        X1 = X_3D.view(-1, X_3D.size(-1))
-
-        X1_int8, state_X1 = quantize_rowwise_nogroup(X1)
-        W1_int8, state_W1 = quantize_global(W1)
-
-        X2_pre = int8_matmul_mixed_dequanitze_bias(
-            X1_int8, W1_int8.t(), state_X1, state_W1, B1
-        )
-
-        # X2_v1 = torch.nn.functional.gelu(X2)
-        # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1)
-        X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre)
-
-        W2_int8, state_W2 = quantize_global(W2)
-
-        out = int8_matmul_mixed_dequanitze_bias(
-            X2_int8, W2_int8.t(), state_X2, state_W2, B2
-        )
-
-        ctx.save_for_backward = X1, W1, X2, X2_pre, W2
-
-        return out.view(*X_3D.size()[:-1], -1)
-
-    @staticmethod
-    def backward(ctx, G_3D):
-
-        G2 = G_3D.reshape(-1, G_3D.size(-1))
-
-        grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None
-
-        X1, W1, X2, X2_pre, W2 = ctx.save_for_backward
-
-        G2_int8, state_G2 = quantize_rowwise_nogroup(G2)
-        W2_int8, state_W2 = quantize_global_transpose(W2)
-
-        G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view(
-            *G_3D.size()[:-1], -1
-        )
-
-        grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype))
-        grad_B2 = G2.sum(dim=0)
-
-        G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre)
-
-        if ctx.needs_input_grad[0]:
-
-            W1_int8, state_W1 = quantize_global_transpose(W1)
-            grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view(
-                *G_3D.size()[:-1], -1
-            )
-        if ctx.needs_input_grad[1]:
-            grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype))
-        if ctx.needs_input_grad[2]:
-            grad_B1 = G1.sum(dim=0)
-
-        return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2
-
-
-class SwitchBackGlobalMLP(nn.Module):
-
-    def __init__(self, dim_in, dim_hidden):
-        super().__init__()
-        self.linear1 = nn.Linear(dim_in, dim_hidden)
-        self.linear2 = nn.Linear(dim_hidden, dim_in)
-
-    def forward(self, x):
-        return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias)
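Note: the fused `_switchback_mlp` path shown in this hunk computes the same function as a plain linear → GELU → linear block, with row-wise/global int8 quantization wrapped around the two matmuls. A plain-precision reference (hypothetical helper, not part of this diff) that a test could compare the fused output against:

import torch
import torch.nn.functional as F

def reference_mlp(x, w1, b1, w2, b2):
    # unquantized counterpart of the fused forward: (x @ w1.T + b1) -> GELU -> (h @ w2.T + b2)
    h = F.gelu(F.linear(x, w1, b1))
    return F.linear(h, w2, b2)

Because the fused path quantizes activations row-wise and weights globally to int8, only approximate agreement with this reference is expected.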
@@ -1,7 +1,7 @@

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 # class AttentionOld(torch.nn.Module):
@@ -116,7 +116,7 @@ if __name__ == '__main__':
     va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)

     standard = Attention(dim).cuda()
-    my_standard = Attention(dim, linear_module=MyLinear).cuda()
+    my_standard = Attention(dim, linear_module=StandardLinear).cuda()
     sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
     standard_compiled = torch.compile(standard)
     ln_model = torch.nn.Sequential(
@@ -360,4 +360,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()


-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
@@ -4,7 +4,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear

 from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
 from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
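Note: the `quantize_rowwise_nogroup` kernel imported above performs per-row absmax int8 quantization; a rough plain-PyTorch equivalent (conceptual sketch only; the Triton kernel fuses this and may differ in rounding and in its return convention) is:

import torch

def quantize_rowwise_ref(x):
    # symmetric int8 quantization with one absmax scale per row
    state = x.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)   # per-row scales
    x_int8 = torch.round(x * (127.0 / state)).clamp(-127, 127).to(torch.int8)
    return x_int8, state.squeeze(1)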
@@ -350,4 +350,4 @@ if __name__ == '__main__':


     with open("tests/triton_tests/info.jsonl", "a") as file:
-        file.write(info_json + "\n")
+        file.write(info_json + "\n")
@@ -3,7 +3,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear

 def construct_model(dim, layers, module):
     modules = []
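Note: `construct_model` is only partially visible in this hunk. One plausible shape for such a helper (illustrative guess, not the file's actual body) is a stack of `layers` blocks built from the supplied linear module:

import torch.nn as nn

def construct_model_sketch(dim, layers, module):
    # hypothetical: alternate the supplied linear module with GELU
    modules = []
    for _ in range(layers):
        modules.append(module(dim, 4 * dim))
        modules.append(nn.GELU())
        modules.append(module(4 * dim, dim))
    return nn.Sequential(*modules).cuda()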
@@ -41,7 +41,7 @@ if __name__ == '__main__':

     # construct models
     standard = construct_model(dim, layers, nn.Linear).half()
-    my_standard = construct_model(dim, layers, MyLinear).half()
+    my_standard = construct_model(dim, layers, StandardLinear).half()
     switchback = construct_model(dim, layers, SwitchBackLinear).half()
     switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
     #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)
@@ -61,4 +61,4 @@ if __name__ == '__main__':



@@ -1,7 +1,7 @@

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 if __name__ == '__main__':
@@ -26,9 +26,9 @@ if __name__ == '__main__':
     ).cuda()

     my_standard = torch.nn.Sequential(
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()

     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
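Note: since `StandardLinear` subclasses `nn.Linear` and only swaps in the custom autograd function, the `my_standard` stack above should match a stock `nn.Linear` stack to floating-point tolerance once the weights are copied over. A hedged sanity-check sketch (hypothetical test code; `dim` and `my_standard` come from the surrounding script):

ref = torch.nn.Sequential(
    torch.nn.Linear(dim, 4 * dim),
    torch.nn.GELU(),
    torch.nn.Linear(4 * dim, dim),
).cuda()
ref.load_state_dict(my_standard.state_dict())   # StandardLinear subclasses nn.Linear, so keys match

x = torch.randn(8, 256, dim, device='cuda')
torch.testing.assert_close(ref(x), my_standard(x))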
@@ -163,4 +163,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()


-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
@@ -1,7 +1,7 @@

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 if __name__ == '__main__':
@@ -24,9 +24,9 @@ if __name__ == '__main__':

     my_standard = torch.nn.Sequential(
         torch.nn.LayerNorm(dim),
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()

     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
@@ -162,4 +162,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()


-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.