local ddp fix
parent 3337c69e5a
commit 88e9b9caff
@@ -115,22 +115,25 @@ class Engine():
 		return dispatch_attribute(self.module, *args, **kwargs)
 
 	def save_checkpoint(self, save_dir, tag ):
-		save_path = save_dir / tag / "state.pth"
-		save_path.parent.mkdir(parents=True, exist_ok=True)
-		torch.save({
-			"module": self.module.state_dict(),
-			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
-			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-
-			"stats": {
-				"global_step": self.global_step,
-				"micro_step": self.micro_step,
-				"global_samples": self.global_samples,
-				"tokens_processed": self.tokens_processed,
-			}
-		}, save_path)
-
-		open(save_dir / "latest", 'w').write( tag )
+		if is_global_leader():
+			save_path = save_dir / tag / "state.pth"
+			save_path.parent.mkdir(parents=True, exist_ok=True)
+			torch.save({
+				"module": self.module.state_dict(),
+				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+
+				"stats": {
+					"global_step": self.global_step,
+					"micro_step": self.micro_step,
+					"global_samples": self.global_samples,
+					"tokens_processed": self.tokens_processed,
+				}
+			}, save_path)
+
+			open(save_dir / "latest", 'w').write( tag )
+
+		torch.distributed.barrier()
 
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
 		if tag is None:
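For reference, the hunk above follows the usual pattern for checkpointing under DDP: only the global leader (rank 0) writes the files, and every rank then waits at a barrier so no process races ahead to read a half-written checkpoint. A minimal standalone sketch of that pattern, with is_global_leader() and the saved fields as stand-ins for what the repo defines:

from pathlib import Path

import torch
import torch.distributed as dist

def is_global_leader() -> bool:
	# Stand-in for the repo's helper: treat single-process runs and rank 0 as the leader.
	return (not dist.is_initialized()) or dist.get_rank() == 0

def save_checkpoint(module, optimizer, save_dir: Path, tag: str):
	if is_global_leader():
		save_path = save_dir / tag / "state.pth"
		save_path.parent.mkdir(parents=True, exist_ok=True)
		torch.save({
			"module": module.state_dict(),
			"optimizer": optimizer.state_dict() if optimizer is not None else None,
		}, save_path)
		(save_dir / "latest").write_text(tag)

	# Keep the other ranks from moving on before the files exist.
	# Note: barrier() needs an initialized process group, so guard it
	# for non-distributed runs.
	if dist.is_initialized():
		dist.barrier()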
@@ -154,10 +157,10 @@ class Engine():
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
 
 		if load_optimizer_states:
-			self.optimizer.load_state_dict(state['optimizer'], map_location=torch.device(cfg.device))
+			self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))
 
 		if load_lr_scheduler_states:
-			self.lr_scheduler.load_state_dict(state['lr_scheduler'], map_location=torch.device(cfg.device))
+			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
 
 	def eval(self):
 		return self.module.eval()
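The load-side change reflects that optimizer and LR-scheduler load_state_dict() take no map_location argument; device mapping belongs to torch.load() when the checkpoint is read. A minimal sketch of that, assuming a device string such as the repo's cfg.device:

import torch

def load_state(path, module, optimizer=None, lr_scheduler=None, device="cuda"):
	# Map all saved tensors onto the target device while reading the file;
	# the individual load_state_dict() calls then need no device argument.
	state = torch.load(path, map_location=torch.device(device))

	module.load_state_dict(state["module"])
	if optimizer is not None and state.get("optimizer") is not None:
		optimizer.load_state_dict(state["optimizer"])
	if lr_scheduler is not None and state.get("lr_scheduler") is not None:
		lr_scheduler.load_state_dict(state["lr_scheduler"])
	return state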
@@ -211,6 +211,7 @@ try:
 			else:
 				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
 		else:
+			#torch.nn.attention.sdpa_kernel
 			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
 
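The added comment points at torch.nn.attention.sdpa_kernel, the newer replacement for the deprecated torch.backends.cuda.sdp_kernel context manager used here. A rough sketch of the equivalent backend selection (the mode strings mirror the ones above; availability depends on the installed PyTorch version, roughly 2.3+):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Map the string modes used above onto SDPA backends.
BACKENDS = {
	"flash": SDPBackend.FLASH_ATTENTION,
	"math": SDPBackend.MATH,
	"mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
}

def attention(q, k, v, mode="math", attn_mask=None):
	# Restrict scaled_dot_product_attention to the requested backend only.
	with sdpa_kernel(BACKENDS[mode]):
		return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)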