local ddp fix

mrq 2024-05-11 17:29:01 -05:00
parent 3337c69e5a
commit 88e9b9caff
2 changed files with 21 additions and 17 deletions

View File

@@ -115,6 +115,7 @@ class Engine():
 		return dispatch_attribute(self.module, *args, **kwargs)

 	def save_checkpoint(self, save_dir, tag ):
-		save_path = save_dir / tag / "state.pth"
-		save_path.parent.mkdir(parents=True, exist_ok=True)
-		torch.save({
+		if is_global_leader():
+			save_path = save_dir / tag / "state.pth"
+			save_path.parent.mkdir(parents=True, exist_ok=True)
+			torch.save({
@@ -132,6 +133,8 @@ class Engine():
-		open(save_dir / "latest", 'w').write( tag )
+			open(save_dir / "latest", 'w').write( tag )

+		torch.distributed.barrier()
+
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
 		if tag is None:
 			tag_path = load_dir / "latest"
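
The two hunks above apply the usual pattern for checkpointing under DDP: only the global leader (rank 0) writes state.pth and the latest tag, and every rank then waits on a collective barrier so no process runs ahead of the write. A minimal sketch of that pattern using torch.distributed directly (the function name save_checkpoint_ddp and the dist.is_initialized() guard are illustrative, not code from this repo):

import torch
import torch.distributed as dist

def save_checkpoint_ddp(state: dict, save_path):
	# Only the global leader (rank 0) touches the filesystem; letting every
	# rank write the same file is a race.
	if not dist.is_initialized() or dist.get_rank() == 0:
		torch.save(state, save_path)
	# Every rank synchronizes here, so none of them moves on (e.g. to a
	# subsequent load, or to process exit) before the checkpoint exists.
	if dist.is_initialized():
		dist.barrier()
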
@@ -154,10 +157,10 @@ class Engine():
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state

 		if load_optimizer_states:
-			self.optimizer.load_state_dict(state['optimizer'], map_location=torch.device(cfg.device))
+			self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))

 		if load_lr_scheduler_states:
-			self.lr_scheduler.load_state_dict(state['lr_scheduler'], map_location=torch.device(cfg.device))
+			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))

 	def eval(self):
 		return self.module.eval()
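
The dropped map_location kwargs fix the loading side: the optimizer's and scheduler's load_state_dict() accept only the state dict, so passing map_location raises a TypeError. Device remapping belongs to torch.load(), and the optimizer then moves its loaded state onto the devices of the parameters it already holds. A hedged sketch of the corrected flow, reusing cfg.device and the 'optimizer' / 'lr_scheduler' keys from the diff (load_path stands in for the state.pth path and is not a name from this file):

import torch

# Remap the whole checkpoint onto the target device at deserialization time.
state = torch.load(load_path, map_location=torch.device(cfg.device))

optimizer.load_state_dict(state['optimizer'])        # no map_location kwarg
lr_scheduler.load_state_dict(state['lr_scheduler'])  # no map_location kwarg
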

View File

@@ -211,6 +211,7 @@ try:
 			else:
 				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
 		else:
+			#torch.nn.attention.sdpa_kernel
 			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
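
The #torch.nn.attention.sdpa_kernel comment points at the replacement API: torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases in favor of torch.nn.attention.sdpa_kernel, which selects backends via the SDPBackend enum rather than boolean flags. A sketch of the equivalent selection, assuming PyTorch 2.2 or newer and reusing self.mode, the query/key/value states, and attention_mask from the surrounding code:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Map the existing mode strings onto the SDPBackend enum members.
backend = {
	"flash": SDPBackend.FLASH_ATTENTION,
	"math": SDPBackend.MATH,
	"mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
}[self.mode]

with sdpa_kernel(backend):
	attn_output = torch.nn.functional.scaled_dot_product_attention(
		query_states, key_states, value_states, attn_mask=attention_mask
	)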