Fixed prefetch bug for non-paged tensors; added benchmark.
parent 41a9c70814
commit f64cfe65aa
@@ -314,9 +314,12 @@ class Optimizer8bit(torch.optim.Optimizer):
     def prefetch_state(self, p):
         if self.is_paged:
             state = self.state[p]
-            F.prefetch_tensor(state['state1'])
-            if 'state2' in state:
-                F.prefetch_tensor(state['state2'])
+            s1 = state['state1']
+            is_paged = getattr(s1, 'is_paged', False)
+            if is_paged:
+                F.prefetch_tensor(state['state1'])
+                if 'state2' in state:
+                    F.prefetch_tensor(state['state2'])


 class Optimizer2State(Optimizer8bit):
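This hunk makes prefetch_state() check each state tensor's own is_paged attribute instead of relying only on the optimizer-level self.is_paged flag, since a paged optimizer can still hold ordinary (non-paged) state buffers. Below is a minimal sketch of the guarded pattern, assuming bnb.optim.PagedAdamW and illustrative layer sizes; the maybe_prefetch helper is hypothetical and not part of this commit:

    # Hedged sketch: with a paged optimizer some state buffers may still be
    # plain CUDA tensors, so prefetching has to be skipped for them
    # (the bug fixed above). Sizes and names below are illustrative only.
    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    def maybe_prefetch(t):
        # Same per-tensor guard the fixed prefetch_state() now applies.
        if getattr(t, 'is_paged', False):
            F.prefetch_tensor(t)

    model = torch.nn.Linear(2048, 2048).cuda()
    optim = bnb.optim.PagedAdamW(model.parameters())
    model(torch.randn(8, 2048, device='cuda')).sum().backward()
    optim.step()  # lazily allocates the optimizer state buffers

    for p in model.parameters():
        state = optim.state[p]
        maybe_prefetch(state['state1'])  # no-op if state1 is not paged
        if 'state2' in state:
            maybe_prefetch(state['state2'])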
@@ -490,3 +490,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     params = (k - k // 5) * dim1 * dim2
     print(optim_name, gtype, s / params)
     # assert s < 3.9
+
+dim1 = [10*1024]
+gtype = [torch.float16]
+#mode = ['torch', 'bnb']
+mode = ['bnb']
+optimizer_names = ['paged_adamw']
+#optimizer_names = ['paged_adamw8bit_blockwise']
+values = list(product(dim1, gtype, optimizer_names, mode))
+names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
+def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
+    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
+    layers1 = layers1.to(gtype)
+    layers1 = layers1.cuda()
+
+    large_tensor = None
+    if mode == 'torch':
+        optim = str2optimizers[optim_name][0](layers1.parameters())
+    else:
+        optim = str2optimizers[optim_name][1](layers1.parameters())
+        # 12 GB
+        large_tensor = torch.empty((int(4.5e9),), device='cuda')
+
+    torch.cuda.synchronize()
+    time.sleep(5)
+
+    num_batches = 5
+    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
+
+    for i in range(num_batches):
+        print(i)
+        b = batches[i]
+        if i == 2:
+            torch.cuda.synchronize()
+            t0 = time.time()
+
+        out1 = layers1(b)
+
+        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
+        loss1.backward()
+        optim.step()
+    torch.cuda.synchronize()
+    print(mode, time.time() - t0)
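Note on the new benchmark: test_stream_optimizer_bench measures wall-clock time rather than asserting correctness. In 'bnb' mode it allocates a large filler tensor so the paged optimizer state cannot stay resident in GPU memory, starts the timer at the third batch, and prints the elapsed time for the remaining steps after the loop, so warmup is excluded. str2optimizers maps each name to a (PyTorch reference, bitsandbytes) pair, so index [0] selects the torch baseline and [1] the bnb implementation. Since there are no assertions, run it with pytest's -s flag to see the printed timings, e.g. pytest tests/test_optim.py -k test_stream_optimizer_bench -s (the file path is assumed from the surrounding hunk).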