diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index a6ed675..2542e4b 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -100,7 +100,10 @@ class GlobalPageManager:
         return cls._instance
 
     def prefetch_all(self, to_cpu=False):
-        for t in self.paged_tensors:
+        # assume the tensors added first will be the
+        # ones used first, so swap them in last
+        # in case they are evicted again
+        for t in self.paged_tensors[::-1]:
             prefetch_tensor(t, to_cpu)
 
 
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 4f8dcc7..921ec0a 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             self.to_gpu()  # needed for fairseq pure fp16 training
             self.initialized = True
 
-        if self.is_paged: self.page_mng.prefetch_all()
+        #if self.is_paged: self.page_mng.prefetch_all()
         for gindex, group in enumerate(self.param_groups):
             for pindex, p in enumerate(group["params"]):
                 if p.grad is None:
@@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                 if len(state) == 0:
                     self.init_state(group, p, gindex, pindex)
 
+                self.prefetch_state(p)
                 self.update_step(group, p, gindex, pindex)
+                torch.cuda.synchronize()
         if self.is_paged:
             # all paged operations are asynchronous; we need
             # to sync to make sure all tensors are in the right state
@@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer):
             self.page_mng.paged_tensors.append(buff)
         return buff
 
+    def prefetch_state(self, p):
+        if self.is_paged:
+            state = self.state[p]
+            F.prefetch_tensor(state['state1'])
+            if 'state2' in state:
+                F.prefetch_tensor(state['state2'])
+
 
 class Optimizer2State(Optimizer8bit):
     def __init__(
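
For context, a minimal sketch of the reasoning behind the reversed iteration in prefetch_all: under an LRU-style eviction policy, the most recently prefetched pages are the last to be evicted, so prefetching the first-used tensors last keeps them resident when step() reaches them. The list and names below are illustrative stand-ins, not part of the patch:

    # hypothetical stand-in for GlobalPageManager.paged_tensors,
    # appended in the order the optimizer step will use them
    paged_tensors = ["state_p0", "state_p1", "state_p2"]

    use_order = paged_tensors             # step() touches params in this order
    prefetch_order = paged_tensors[::-1]  # the patch swaps in the last-used first

    # the first tensor step() needs is the most recently prefetched one,
    # i.e. the page least likely to have been evicted in the meantime
    assert prefetch_order[-1] == use_order[0]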