bitsandbytes-rocm/tests/triton_tests/mlp.py
Mitchell Wortsman 5f3d9ada8d triton-v1
2023-03-29 06:47:08 +00:00

64 lines
1.8 KiB
Python

import time
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
def construct_model(dim, layers, module, expansion=4, device="cuda"):
    """Build a stack of MLP up/down projection pairs for benchmarking.

    Each of the ``layers`` pairs consists of two ``module`` layers,
    ``dim -> expansion * dim`` followed by ``expansion * dim -> dim``, appended
    back-to-back with no activation in between.

    Args:
        dim: Input/output width of every pair.
        layers: Number of up/down projection pairs.
        module: Layer factory called as ``module(in_features, out_features)``,
            e.g. ``nn.Linear``.
        expansion: Hidden-width multiplier (defaults to the previously
            hard-coded 4).
        device: Device the model is moved to (defaults to ``"cuda"``, which is
            equivalent to the previous ``.cuda()`` call).

    Returns:
        An ``nn.Sequential`` of ``2 * layers`` modules on ``device``, in
        training mode.
    """
    hidden = expansion * dim
    blocks = []
    for _ in range(layers):
        blocks.append(module(dim, hidden))
        blocks.append(module(hidden, dim))
    return nn.Sequential(*blocks).to(device).train()
def get_time(model, x, name, repeat=None, backward=False):
    """Benchmark the mean per-iteration latency of ``model`` on input ``x``.

    Runs ``repeat // 2`` untimed warm-up iterations, then times ``repeat``
    iterations and prints ``time {name}: <ms> ms``. Autograd is left enabled
    during the forward pass, as in the original benchmark.

    Args:
        model: Callable (typically an ``nn.Module``) applied as ``model(x)``.
        x: Input tensor. When it is on a CUDA/ROCm device, the GPU is
            synchronized around the timed region so queued kernels are
            included in the measurement.
        name: Label used in the printed line.
        repeat: Number of timed iterations. ``None`` falls back to a
            module-level ``repeat`` global (set by the ``__main__`` block), or
            16 when no such global exists.
        backward: If True, each iteration also back-propagates the scaled loss
            ``2**16 * out.pow(2).mean()``. Gradients accumulate across
            iterations (no ``zero_grad``), which is fine for timing.

    Returns:
        Mean wall-clock time per timed iteration, in milliseconds.

    Raises:
        ValueError: If ``repeat`` is not positive.
    """
    if repeat is None:
        # Backward compatibility: this function originally read the global
        # ``repeat`` defined only under ``__main__``; fall back to it when set.
        repeat = globals().get("repeat", 16)
    if repeat <= 0:
        raise ValueError(f"repeat must be positive, got {repeat}")

    def step():
        # ``out`` is local, so each iteration's autograd graph is released
        # when the step returns instead of surviving into the next forward.
        out = model(x)
        if backward:
            # Constant loss scale carried over from the original experiment
            # (presumably to keep fp16 gradients from underflowing).
            (2**16 * out.pow(2).mean()).backward()

    def sync():
        # GPU kernels launch asynchronously; wait for completion so the timer
        # measures finished work. Nothing to wait for with CPU inputs.
        if x.is_cuda:
            torch.cuda.synchronize()

    for _ in range(repeat // 2):
        step()
    sync()
    start = time.perf_counter()
    for _ in range(repeat):
        step()
    sync()
    elapsed_ms = (time.perf_counter() - start) / repeat * 1000
    print(f"time {name}: {elapsed_ms:.3f} ms")
    return elapsed_ms
if __name__ == '__main__':
    torch.manual_seed(0)

    # Benchmark hyperparameters. ``repeat`` must stay a module-level name:
    # get_time() reads it as a global for its iteration count.
    repeat = 16
    dim = 2048
    layers = 4
    batch_size = 2
    sequence_length = 2**15

    # One fp16 MLP stack per linear-layer implementation under test. All
    # models are constructed before the input tensor is sampled.
    layer_types = [
        ("standard", nn.Linear),
        ("my_standard", MyLinear),
        ("switchback", SwitchBackLinear),
        ("switchback_global", SwitchBackGlobalLinear),
    ]
    models = {
        name: construct_model(dim, layers, layer_cls).half()
        for name, layer_cls in layer_types
    }
    # Disabled: bitsandbytes 8-bit mixed-precision layer (was not cast to fp16).
    # models["bnb_8bitmixed"] = construct_model(dim, layers, bnn.Linear8bitLt)

    # Flattened (batch * sequence, dim) input, sampled directly on the GPU
    # rather than on the host and then copied over.
    x = torch.randn(
        batch_size * sequence_length, dim, dtype=torch.float16, device="cuda"
    )

    # Time each model. get_time() is called without ``backward``, so only the
    # forward pass (with autograd enabled) is measured.
    for name, model in models.items():
        get_time(model, x, name)