Changed prefetching.
This commit is contained in:
parent
44d68ff29c
commit
41a9c70814
|
@ -100,7 +100,10 @@ class GlobalPageManager:
|
|||
return cls._instance
|
||||
|
||||
def prefetch_all(self, to_cpu=False):
|
||||
for t in self.paged_tensors:
|
||||
# assume the first added, will be the
|
||||
# ones that are used first, so swap them in last
|
||||
# in the case they are evicted again
|
||||
for t in self.paged_tensors[::-1]:
|
||||
prefetch_tensor(t, to_cpu)
|
||||
|
||||
|
||||
|
|
|
@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
self.to_gpu() # needed for fairseq pure fp16 training
|
||||
self.initialized = True
|
||||
|
||||
if self.is_paged: self.page_mng.prefetch_all()
|
||||
#if self.is_paged: self.page_mng.prefetch_all()
|
||||
for gindex, group in enumerate(self.param_groups):
|
||||
for pindex, p in enumerate(group["params"]):
|
||||
if p.grad is None:
|
||||
|
@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
if len(state) == 0:
|
||||
self.init_state(group, p, gindex, pindex)
|
||||
|
||||
self.prefetch_state(p)
|
||||
self.update_step(group, p, gindex, pindex)
|
||||
torch.cuda.synchronize()
|
||||
if self.is_paged:
|
||||
# all paged operations are asynchronous, we need
|
||||
# to sync to make sure all tensors are in the right state
|
||||
|
@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
self.page_mng.paged_tensors.append(buff)
|
||||
return buff
|
||||
|
||||
def prefetch_state(self, p):
    """Issue prefetch requests for the optimizer state tensors of ``p``.

    Does nothing unless ``self.is_paged`` is truthy. Otherwise looks up
    ``self.state[p]`` and passes ``state1`` (always) and ``state2`` (only
    when that key is present) to ``F.prefetch_tensor``.

    Args:
        p: The parameter used as the key into ``self.state``.
    """
    if not self.is_paged:
        return

    param_state = self.state[p]
    F.prefetch_tensor(param_state['state1'])
    # Not every state dict carries a second buffer, so check for the key.
    if 'state2' in param_state:
        F.prefetch_tensor(param_state['state2'])
|
||||
|
||||
|
||||
class Optimizer2State(Optimizer8bit):
|
||||
def __init__(
|
||||
|
|
Loading…
Reference in New Issue
Block a user