Added first triton test.
parent b373034e31 · commit a13a522c4c
@@ -133,7 +133,7 @@ class SwitchBackGlobalLinear(nn.Linear):



-class LinearFunction(torch.autograd.Function):
+class StandardLinearFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, weight, bias=None):
         X = input.view(-1, input.size(-1))
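For context: `StandardLinearFunction`/`StandardLinear` (formerly `LinearFunction`/`MyLinear`) is the hand-written floating-point baseline the triton tests compare against. The hunk only shows the first and last lines of its body, so the sketch below is a minimal illustration of what such a custom-autograd linear typically looks like, not the file's exact code; the names `RefLinearFunction`/`RefLinear` are illustrative.

```python
import torch
import torch.nn as nn

class RefLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))        # flatten leading dims
        out = X.mm(weight.t())                    # X @ W^T
        if bias is not None:
            out = out + bias
        ctx.save_for_backward(input, weight, bias)
        return out.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        G = grad_output.reshape(-1, grad_output.size(-1))
        X = input.view(-1, input.size(-1))
        grad_input = G.mm(weight).view(*input.size())
        grad_weight = G.t().mm(X)
        grad_bias = G.sum(dim=0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

class RefLinear(nn.Linear):
    def forward(self, x):
        return RefLinearFunction.apply(x, self.weight, self.bias)
```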
@@ -161,87 +161,8 @@ class LinearFunction(torch.autograd.Function):

         return grad_input, grad_weight, grad_bias

-class MyLinear(nn.Linear):
+class StandardLinear(nn.Linear):

     def forward(self, x):
-        return LinearFunction.apply(x, self.weight, self.bias)
+        return StandardLinearFunction.apply(x, self.weight, self.bias)

-
-
-
-class _switchback_mlp(torch.autograd.Function):
-
-
-    @staticmethod
-    def forward(ctx, X_3D, W1, B1, W2, B2):
-
-        X1 = X_3D.view(-1, X_3D.size(-1))
-
-        X1_int8, state_X1 = quantize_rowwise_nogroup(X1)
-        W1_int8, state_W1 = quantize_global(W1)
-
-        X2_pre = int8_matmul_mixed_dequanitze_bias(
-            X1_int8, W1_int8.t(), state_X1, state_W1, B1
-        )
-
-        # X2_v1 = torch.nn.functional.gelu(X2)
-        # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1)
-        X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre)
-
-        W2_int8, state_W2 = quantize_global(W2)
-
-        out = int8_matmul_mixed_dequanitze_bias(
-            X2_int8, W2_int8.t(), state_X2, state_W2, B2
-        )
-
-        ctx.save_for_backward = X1, W1, X2, X2_pre, W2
-
-        return out.view(*X_3D.size()[:-1], -1)
-
-    @staticmethod
-    def backward(ctx, G_3D):
-
-        G2 = G_3D.reshape(-1, G_3D.size(-1))
-
-        grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None
-
-        X1, W1, X2, X2_pre, W2 = ctx.save_for_backward
-
-        G2_int8, state_G2 = quantize_rowwise_nogroup(G2)
-        W2_int8, state_W2 = quantize_global_transpose(W2)
-
-        G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view(
-            *G_3D.size()[:-1], -1
-        )
-
-        grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype))
-        grad_B2 = G2.sum(dim=0)
-
-        G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre)
-
-        if ctx.needs_input_grad[0]:
-
-            W1_int8, state_W1 = quantize_global_transpose(W1)
-            grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view(
-                *G_3D.size()[:-1], -1
-            )
-        if ctx.needs_input_grad[1]:
-            grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype))
-        if ctx.needs_input_grad[2]:
-            grad_B1 = G1.sum(dim=0)
-
-        return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2
-
-
-class SwitchBackGlobalMLP(nn.Module):
-
-
-    def __init__(self, dim_in, dim_hidden):
-        super().__init__()
-        self.linear1 = nn.Linear(dim_in, dim_hidden)
-        self.linear2 = nn.Linear(dim_hidden, dim_in)
-
-
-    def forward(self, x):
-        return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias)
-
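The removed `_switchback_mlp` block, like the `SwitchBackGlobalLinear` path the new tests exercise, is built on one int8 scheme: row-wise absmax quantization for activations and gradients, a single global absmax scale for weights, and an int8 matmul that dequantizes and adds the bias in one pass. Below is a rough pure-PyTorch reference of that scheme for readers of the test, with illustrative names and the usual absmax-to-127 convention assumed; the triton kernels' exact scaling and fusion details may differ.

```python
import torch

def ref_quantize_rowwise(X):
    # per-row absmax scale -> int8
    state = X.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    X_int8 = torch.clamp(torch.round(X / state * 127), -127, 127).to(torch.int8)
    return X_int8, state.squeeze(1)

def ref_quantize_global(W):
    # one absmax scale for the whole weight tensor
    state = W.abs().max().clamp(min=1e-8)
    W_int8 = torch.clamp(torch.round(W / state * 127), -127, 127).to(torch.int8)
    return W_int8, state

def ref_int8_matmul_dequantize_bias(X_int8, Wt_int8, state_X, state_W, bias):
    # integer matmul, then rescale rows by state_X and everything by state_W
    acc = X_int8.to(torch.float32) @ Wt_int8.to(torch.float32)
    return acc * (state_X[:, None] / 127.0) * (state_W / 127.0) + bias

# usage: Y ≈ X @ W.t() + b, computed through int8
# X_int8, sx = ref_quantize_rowwise(X); W_int8, sw = ref_quantize_global(W)
# Y = ref_int8_matmul_dequantize_bias(X_int8, W_int8.t(), sx, sw, b)
```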
@@ -1,7 +1,7 @@

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 # class AttentionOld(torch.nn.Module):
@@ -116,7 +116,7 @@ if __name__ == '__main__':
     va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)

     standard = Attention(dim).cuda()
-    my_standard = Attention(dim, linear_module=MyLinear).cuda()
+    my_standard = Attention(dim, linear_module=StandardLinear).cuda()
     sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
     standard_compiled = torch.compile(standard)
     ln_model = torch.nn.Sequential(
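This hunk only changes which linear module the `my_standard` attention variant uses; the measurement code sits outside the hunk. For orientation, such a benchmark is typically driven by a generic forward+backward timing loop like the sketch below (illustrative helper, not the file's exact code).

```python
import time
import torch

def bench(model, x, iters=100, warmup=10):
    # wall-clock timing of forward + backward, with CUDA synchronization
    for _ in range(warmup):
        model(x).sum().backward()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(x).sum().backward()
    torch.cuda.synchronize()
    return (time.time() - start) / iters

# e.g. bench(standard, va), bench(my_standard, va), bench(sb, va), bench(standard_compiled, va)
```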
@@ -4,7 +4,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear

 from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
 from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
@@ -3,7 +3,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear

 def construct_model(dim, layers, module):
     modules = []
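`construct_model` is only partially visible here (its signature and the empty `modules` list). The sketch below is a guess at the kind of stack such a helper builds from a given linear module, for orientation only; the real layer pattern in the script may differ.

```python
import torch.nn as nn

def construct_model_sketch(dim, layers, module):
    # stack `layers` MLP-style blocks built from the supplied linear module
    modules = []
    for _ in range(layers):
        modules.append(module(dim, 4 * dim))   # expand
        modules.append(nn.GELU())
        modules.append(module(4 * dim, dim))   # project back
    return nn.Sequential(*modules)
```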
@@ -41,7 +41,7 @@ if __name__ == '__main__':

     # construct models
     standard = construct_model(dim, layers, nn.Linear).half()
-    my_standard = construct_model(dim, layers, MyLinear).half()
+    my_standard = construct_model(dim, layers, StandardLinear).half()
     switchback = construct_model(dim, layers, SwitchBackLinear).half()
     switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
     #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)
@@ -1,7 +1,7 @@

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 if __name__ == '__main__':
@@ -26,9 +26,9 @@ if __name__ == '__main__':
     ).cuda()

     my_standard = torch.nn.Sequential(
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()

     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
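Since this `my_standard` Sequential mirrors `SwitchBackGlobalMLP` layer for layer (linear → GELU → linear, no LayerNorm), a quick sanity check is to share weights and compare outputs. A hedged sketch: the Sequential indices (0 and 2) and the fp16 cast are assumptions about this script, `dim` comes from the surrounding code, and the gap is expected to be small but nonzero because the fused path quantizes to int8.

```python
import torch

with torch.no_grad():
    # copy the floating-point weights of the standard MLP into the fused module
    fused_mlp.linear1.weight.copy_(my_standard[0].weight)
    fused_mlp.linear1.bias.copy_(my_standard[0].bias)
    fused_mlp.linear2.weight.copy_(my_standard[2].weight)
    fused_mlp.linear2.bias.copy_(my_standard[2].bias)

x = torch.randn(8, 256, dim, device='cuda', dtype=torch.float16)
err = (my_standard.half()(x) - fused_mlp.half()(x)).abs().max()
print('max abs deviation of fused int8 MLP vs fp16 reference:', err.item())
```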
@@ -1,7 +1,7 @@

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 if __name__ == '__main__':
@@ -24,9 +24,9 @@ if __name__ == '__main__':

     my_standard = torch.nn.Sequential(
         torch.nn.LayerNorm(dim),
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()

     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()