Changed prefetching.

This commit is contained in:
Tim Dettmers 2023-05-06 18:59:59 -07:00
parent 44d68ff29c
commit 41a9c70814
2 changed files with 14 additions and 2 deletions

View File

@ -100,7 +100,10 @@ class GlobalPageManager:
return cls._instance
def prefetch_all(self, to_cpu=False):
for t in self.paged_tensors:
# assume the first added, will be hte
# ones that are used first, so swap them in last
# in the case they are evicted again
for t in self.paged_tensors[::-1]:
prefetch_tensor(t, to_cpu)

View File

@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
if self.is_paged: self.page_mng.prefetch_all()
#if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer):
if len(state) == 0:
self.init_state(group, p, gindex, pindex)
self.prefetch_state(p)
self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
if self.is_paged:
# all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state
@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer):
self.page_mng.paged_tensors.append(buff)
return buff
def prefetch_state(self, p):
if self.is_paged:
state = self.state[p]
F.prefetch_tensor(state['state1'])
if 'state2' in state:
F.prefetch_tensor(state['state2'])
class Optimizer2State(Optimizer8bit):
def __init__(