Fixed prefetch bug for non-paged tensors; added benchmark.

Tim Dettmers 2023-05-06 21:49:16 -07:00
parent 41a9c70814
commit f64cfe65aa
2 changed files with 50 additions and 3 deletions

@@ -314,9 +314,12 @@ class Optimizer8bit(torch.optim.Optimizer):
     def prefetch_state(self, p):
         if self.is_paged:
             state = self.state[p]
-            F.prefetch_tensor(state['state1'])
-            if 'state2' in state:
-                F.prefetch_tensor(state['state2'])
+            s1 = state['state1']
+            is_paged = getattr(s1, 'is_paged', False)
+            if is_paged:
+                F.prefetch_tensor(state['state1'])
+                if 'state2' in state:
+                    F.prefetch_tensor(state['state2'])


 class Optimizer2State(Optimizer8bit):
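
The change guards the prefetch with the state tensor's own is_paged flag: F.prefetch_tensor expects a tensor that was actually allocated as paged memory, and under a paged optimizer a parameter can still carry ordinary CUDA state tensors, which previously went down the prefetch path and raised an error. Below is a minimal sketch of code that exercises the fixed path; PagedAdamW, the layer size, and the learning rate are illustrative assumptions, not part of this commit.

    import torch
    import bitsandbytes as bnb

    # Sketch (assumes bitsandbytes with paged optimizers and a CUDA device).
    # optimizer.step() calls prefetch_state(p) for each parameter when the
    # optimizer is paged; after this fix, only state tensors whose .is_paged
    # attribute is True are handed to F.prefetch_tensor.
    model = torch.nn.Linear(4096, 4096).cuda().half()
    optim = bnb.optim.PagedAdamW(model.parameters(), lr=1e-4)

    x = torch.randn(8, 4096, device='cuda', dtype=torch.float16)
    model(x).sum().backward()
    optim.step()
    torch.cuda.synchronize()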

@@ -490,3 +490,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     params = (k - k // 5) * dim1 * dim2
     print(optim_name, gtype, s / params)
     # assert s < 3.9
+
+dim1 = [10*1024]
+gtype = [torch.float16]
+#mode = ['torch', 'bnb']
+mode = ['bnb']
+optimizer_names = ['paged_adamw']
+#optimizer_names = ['paged_adamw8bit_blockwise']
+values = list(product(dim1,gtype, optimizer_names, mode))
+names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
+def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
+    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
+    layers1 = layers1.to(gtype)
+    layers1 = layers1.cuda()
+
+    large_tensor = None
+    if mode == 'torch':
+        optim = str2optimizers[optim_name][0](layers1.parameters())
+    else:
+        optim = str2optimizers[optim_name][1](layers1.parameters())
+        # 12 GB
+        large_tensor = torch.empty((int(4.5e9),), device='cuda')
+
+    torch.cuda.synchronize()
+    time.sleep(5)
+
+    num_batches = 5
+    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()
+
+    for i in range(num_batches):
+        print(i)
+        b = batches[i]
+        if i ==2:
+            torch.cuda.synchronize()
+            t0 = time.time()
+
+        out1 = layers1(b)
+        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
+        loss1.backward()
+        optim.step()
+
+    torch.cuda.synchronize()
+    print(mode, time.time() - t0)
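
The benchmark times only the last three of the five batches: the first two iterations serve as warmup, then torch.cuda.synchronize() and time.time() open the timed region, with a final synchronize before reading the clock. In 'bnb' mode a float32 dummy tensor of 4.5e9 elements (about 18 GB at 4 bytes per element) is kept alive, presumably to create GPU memory pressure so that paged optimizer state is actually evicted and the prefetch path is exercised. Assuming the test was added to tests/test_optim.py next to the existing str2optimizers table, it can be run with something like pytest tests/test_optim.py -k test_stream_optimizer_bench -s (path and flags are an assumption, not part of the diff). The warmup-then-time pattern, as a standalone sketch that is not part of the commit:

    import time
    import torch

    # Sketch of the timing scheme used above: run `warmup` untimed iterations,
    # synchronize, time the remaining ones, and synchronize again before
    # reading the clock so queued CUDA work is included.
    def bench_steps(step_fn, num_iters=5, warmup=2):
        t0 = None
        for i in range(num_iters):
            if i == warmup:
                torch.cuda.synchronize()
                t0 = time.time()
            step_fn(i)
        torch.cuda.synchronize()
        return time.time() - t0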