Added first triton test.

commit a13a522c4c (parent b373034e31)
Author: Tim Dettmers
Date:   2023-03-31 11:20:54 -07:00
6 changed files with 19 additions and 98 deletions


@@ -133,7 +133,7 @@ class SwitchBackGlobalLinear(nn.Linear):
-class LinearFunction(torch.autograd.Function):
+class StandardLinearFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, weight, bias=None):
         X = input.view(-1, input.size(-1))
@@ -161,87 +161,8 @@ class LinearFunction(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias
-class MyLinear(nn.Linear):
+class StandardLinear(nn.Linear):
     def forward(self, x):
-        return LinearFunction.apply(x, self.weight, self.bias)
+        return StandardLinearFunction.apply(x, self.weight, self.bias)
-class _switchback_mlp(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, X_3D, W1, B1, W2, B2):
-        X1 = X_3D.view(-1, X_3D.size(-1))
-        X1_int8, state_X1 = quantize_rowwise_nogroup(X1)
-        W1_int8, state_W1 = quantize_global(W1)
-        X2_pre = int8_matmul_mixed_dequanitze_bias(
-            X1_int8, W1_int8.t(), state_X1, state_W1, B1
-        )
-        # X2_v1 = torch.nn.functional.gelu(X2)
-        # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1)
-        X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre)
-        W2_int8, state_W2 = quantize_global(W2)
-        out = int8_matmul_mixed_dequanitze_bias(
-            X2_int8, W2_int8.t(), state_X2, state_W2, B2
-        )
-        ctx.save_for_backward = X1, W1, X2, X2_pre, W2
-        return out.view(*X_3D.size()[:-1], -1)
-    @staticmethod
-    def backward(ctx, G_3D):
-        G2 = G_3D.reshape(-1, G_3D.size(-1))
-        grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None
-        X1, W1, X2, X2_pre, W2 = ctx.save_for_backward
-        G2_int8, state_G2 = quantize_rowwise_nogroup(G2)
-        W2_int8, state_W2 = quantize_global_transpose(W2)
-        G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view(
-            *G_3D.size()[:-1], -1
-        )
-        grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype))
-        grad_B2 = G2.sum(dim=0)
-        G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre)
-        if ctx.needs_input_grad[0]:
-            W1_int8, state_W1 = quantize_global_transpose(W1)
-            grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view(
-                *G_3D.size()[:-1], -1
-            )
-        if ctx.needs_input_grad[1]:
-            grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype))
-        if ctx.needs_input_grad[2]:
-            grad_B1 = G1.sum(dim=0)
-        return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2
-class SwitchBackGlobalMLP(nn.Module):
-    def __init__(self, dim_in, dim_hidden):
-        super().__init__()
-        self.linear1 = nn.Linear(dim_in, dim_hidden)
-        self.linear2 = nn.Linear(dim_hidden, dim_in)
-    def forward(self, x):
-        return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias)
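
For context on what the deleted _switchback_mlp forward was doing: activations are quantized row-wise to int8, weights get a single global (tensor-wise) scale, and the int8 matmul result is dequantized back to floating point before the bias is added (with the GELU fused into the second row-wise quantization). The snippet below is a minimal pure-PyTorch emulation of that quantize / int8 matmul / dequantize step, written only to illustrate the math; the "_emulated" helpers are stand-ins for the Triton kernels (quantize_rowwise_nogroup, quantize_global, int8_matmul_mixed_dequanitze_bias), not the library's implementation.

import torch

def quantize_rowwise_emulated(x):
    # one scale per row: the int8 value v stands for v * scale
    scale = x.abs().amax(dim=1, keepdim=True) / 127.0
    x_int8 = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return x_int8, scale

def quantize_global_emulated(w):
    # a single tensor-wide scale for the weight
    scale = w.abs().max() / 127.0
    w_int8 = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)
    return w_int8, scale

def int8_matmul_dequantize_bias_emulated(x_int8, w_int8, x_scale, w_scale, bias):
    # the real kernel accumulates in int32; the float matmul over the int8
    # values emulates that, then the two scales undo the quantization
    acc = x_int8.float() @ w_int8.float().t()
    return acc * x_scale * w_scale + bias

x = torch.randn(8, 16)
w = torch.randn(32, 16)
b = torch.randn(32)
x_q, sx = quantize_rowwise_emulated(x)
w_q, sw = quantize_global_emulated(w)
out = int8_matmul_dequantize_bias_emulated(x_q, w_q, sx, sw, b)
print((out - (x @ w.t() + b)).abs().max())  # only quantization error remains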


@@ -1,7 +1,7 @@
 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time
 # class AttentionOld(torch.nn.Module):
@@ -116,7 +116,7 @@ if __name__ == '__main__':
     va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
     standard = Attention(dim).cuda()
-    my_standard = Attention(dim, linear_module=MyLinear).cuda()
+    my_standard = Attention(dim, linear_module=StandardLinear).cuda()
     sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
     standard_compiled = torch.compile(standard)
     ln_model = torch.nn.Sequential(
@@ -360,4 +360,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()
-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.


@@ -4,7 +4,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
 from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
 from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
@@ -350,4 +350,4 @@ if __name__ == '__main__':
     with open("tests/triton_tests/info.jsonl", "a") as file:
-        file.write(info_json + "\n")
+        file.write(info_json + "\n")


@@ -3,7 +3,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
 def construct_model(dim, layers, module):
     modules = []
@@ -41,7 +41,7 @@ if __name__ == '__main__':
     # construct models
     standard = construct_model(dim, layers, nn.Linear).half()
-    my_standard = construct_model(dim, layers, MyLinear).half()
+    my_standard = construct_model(dim, layers, StandardLinear).half()
     switchback = construct_model(dim, layers, SwitchBackLinear).half()
     switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
     #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)
@@ -61,4 +61,4 @@ if __name__ == '__main__':


@@ -1,7 +1,7 @@
 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time
 if __name__ == '__main__':
@@ -26,9 +26,9 @@ if __name__ == '__main__':
     ).cuda()
     my_standard = torch.nn.Sequential(
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()
     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
@@ -163,4 +163,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()
-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.


@@ -1,7 +1,7 @@
 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time
 if __name__ == '__main__':
@@ -24,9 +24,9 @@ if __name__ == '__main__':
     my_standard = torch.nn.Sequential(
         torch.nn.LayerNorm(dim),
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()
     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
@@ -162,4 +162,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()
-    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
+    # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
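
The StandardLinear module that replaces MyLinear throughout these benchmarks wraps a hand-written torch.autograd.Function (StandardLinearFunction), of which only fragments appear in the hunks above. As a rough guide to what such a baseline looks like, here is a self-contained sketch under the assumption that it implements the ordinary linear forward and matmul gradients; the "Sketch" names are placeholders, not the file's actual code.

import torch
import torch.nn as nn

class StandardLinearFunctionSketch(torch.autograd.Function):
    # illustrative baseline: same math as nn.Linear, spelled out as an
    # explicit autograd.Function so it can be timed like the Triton path
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))  # flatten leading dims
        ctx.save_for_backward(X, weight, bias)
        out = X @ weight.t()
        if bias is not None:
            out = out + bias
        return out.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output):
        X, weight, bias = ctx.saved_tensors
        G = grad_output.reshape(-1, grad_output.size(-1))
        grad_input = (G @ weight).view(*grad_output.size()[:-1], -1)
        grad_weight = G.t() @ X
        grad_bias = G.sum(dim=0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

class StandardLinearSketch(nn.Linear):
    def forward(self, x):
        return StandardLinearFunctionSketch.apply(x, self.weight, self.bias)

# quick shape check on CPU
layer = StandardLinearSketch(64, 256)
y = layer(torch.randn(4, 8, 64, requires_grad=True))
y.sum().backward()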