diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 921ec0a..41c8d27 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -314,9 +314,13 @@ class Optimizer8bit(torch.optim.Optimizer):
     def prefetch_state(self, p):
         if self.is_paged:
             state = self.state[p]
-            F.prefetch_tensor(state['state1'])
-            if 'state2' in state:
-                F.prefetch_tensor(state['state2'])
+            s1 = state['state1']
+            # only prefetch when the state is actually paged; regular CUDA tensors have no 'is_paged' attribute
+            is_paged = getattr(s1, 'is_paged', False)
+            if is_paged:
+                F.prefetch_tensor(state['state1'])
+                if 'state2' in state:
+                    F.prefetch_tensor(state['state2'])
 
 
 class Optimizer2State(Optimizer8bit):
diff --git a/tests/test_optim.py b/tests/test_optim.py
index a5ecb6e..e35408e 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -490,3 +490,48 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     params = (k - k // 5) * dim1 * dim2
     print(optim_name, gtype, s / params)
     # assert s < 3.9
+
+dim1 = [10*1024]
+gtype = [torch.float16]
+#mode = ['torch', 'bnb']
+mode = ['bnb']
+optimizer_names = ['paged_adamw']
+#optimizer_names = ['paged_adamw8bit_blockwise']
+values = list(product(dim1, gtype, optimizer_names, mode))
+names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
+def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
+    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
+    layers1 = layers1.to(gtype)
+    layers1 = layers1.cuda()
+
+    large_tensor = None
+    if mode == 'torch':
+        optim = str2optimizers[optim_name][0](layers1.parameters())
+    else:
+        optim = str2optimizers[optim_name][1](layers1.parameters())
+        # ~18 GB (4.5e9 float32 elements); kept alive to put the GPU under memory pressure
+        large_tensor = torch.empty((int(4.5e9),), device='cuda')
+
+    torch.cuda.synchronize()
+    time.sleep(5)
+
+    num_batches = 5
+    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
+
+    for i in range(num_batches):
+        print(i)
+        b = batches[i]
+        # start timing after two warmup iterations
+        if i == 2:
+            torch.cuda.synchronize()
+            t0 = time.time()
+
+        out1 = layers1(b)
+
+        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
+        loss1.backward()
+        optim.step()
+        torch.cuda.synchronize()
+    print(mode, time.time() - t0)
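
Note: for context, a minimal sketch of the path this patch guards, assuming the library's public paged optimizer class `bnb.optim.PagedAdamW`; the training scaffolding below is illustrative and not part of the patch. `prefetch_state()` is called during `optim.step()`, and with this change it only issues `F.prefetch_tensor()` when the state buffers are actually paged:

    import torch
    import bitsandbytes as bnb

    # paged optimizer: its state tensors can live in pageable memory and are
    # prefetched back to the GPU right before the update step uses them
    model = torch.nn.Linear(1024, 1024).cuda()
    optim = bnb.optim.PagedAdamW(model.parameters())

    out = model(torch.randn(8, 1024, device='cuda'))
    out.sum().backward()
    optim.step()  # prefetch_state() prefetches state1/state2 only if paged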