From 88e9b9caff66f60c8633a6ed335263d00ccabe69 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 May 2024 17:29:01 -0500
Subject: [PATCH] local ddp fix

---
 vall_e/engines/base.py | 37 ++++++++++++++++++++-----------------
 vall_e/models/base.py  |  1 +
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 6d4ccb7..de598e0 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -115,22 +115,25 @@ class Engine():
 		return dispatch_attribute(self.module, *args, **kwargs)
 
 	def save_checkpoint(self, save_dir, tag ):
-		save_path = save_dir / tag / "state.pth"
-		save_path.parent.mkdir(parents=True, exist_ok=True)
-		torch.save({
-			"module": self.module.state_dict(),
-			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
-			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-
-			"stats": {
-				"global_step": self.global_step,
-				"micro_step": self.micro_step,
-				"global_samples": self.global_samples,
-				"tokens_processed": self.tokens_processed,
-			}
-		}, save_path)
+		if is_global_leader():
+			save_path = save_dir / tag / "state.pth"
+			save_path.parent.mkdir(parents=True, exist_ok=True)
+			torch.save({
+				"module": self.module.state_dict(),
+				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+
+				"stats": {
+					"global_step": self.global_step,
+					"micro_step": self.micro_step,
+					"global_samples": self.global_samples,
+					"tokens_processed": self.tokens_processed,
+				}
+			}, save_path)
 
-		open(save_dir / "latest", 'w').write( tag )
+			open(save_dir / "latest", 'w').write( tag )
+
+		torch.distributed.barrier()
 
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
 		if tag is None:
@@ -154,10 +157,10 @@ class Engine():
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
 
 		if load_optimizer_states:
-			self.optimizer.load_state_dict(state['optimizer'], map_location=torch.device(cfg.device))
+			self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))
 
 		if load_lr_scheduler_states:
-			self.lr_scheduler.load_state_dict(state['lr_scheduler'], map_location=torch.device(cfg.device))
+			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
 
 	def eval(self):
 		return self.module.eval()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 6ee5862..dc12392 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -211,6 +211,7 @@ try:
 				else:
 					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
 			else:
+				#torch.nn.attention.sdpa_kernel
 				with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
 					attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
 
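
Notes (not part of the commit): under the local DDP backend, every rank previously ran `save_checkpoint()`, so multiple processes raced on the same `state.pth` and `latest` files. The fix gates the write behind `is_global_leader()` and holds all ranks at `torch.distributed.barrier()` until the write lands. A minimal standalone sketch of that pattern, assuming `torch.distributed` has been initialized (e.g. via `init_process_group`); `is_global_leader()` is written out below as a stand-in for the repo's helper of the same name:

```python
# Minimal sketch of the leader-only save pattern this patch applies.
# Assumes torch.distributed is initialized; is_global_leader() is a
# stand-in for the repo's own helper of that name.
from pathlib import Path

import torch
import torch.distributed as dist

def is_global_leader() -> bool:
	# Rank 0 is the global leader; single-process runs count as the leader too.
	return not dist.is_initialized() or dist.get_rank() == 0

def save_checkpoint(module: torch.nn.Module, save_dir: Path, tag: str):
	# Only the leader writes, so ranks cannot race on state.pth / "latest".
	if is_global_leader():
		save_path = save_dir / tag / "state.pth"
		save_path.parent.mkdir(parents=True, exist_ok=True)
		torch.save({"module": module.state_dict()}, save_path)
		open(save_dir / "latest", 'w').write(tag)
	# Every rank waits here, so none proceeds (or reloads) before the write lands.
	if dist.is_initialized():
		dist.barrier()
```

The load-side change is a plain bug fix: optimizer and scheduler `load_state_dict()` accept only the state dict itself, so the `map_location=` keyword (which belongs to `torch.load`) raised a `TypeError` on resume.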
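The `#torch.nn.attention.sdpa_kernel` comment added in the second file points at the newer kernel-selection API that supersedes the deprecated `torch.backends.cuda.sdp_kernel` context manager. A hedged sketch of that migration, assuming PyTorch >= 2.3; the patch itself keeps the old call:

```python
# Sketch of migrating to torch.nn.attention.sdpa_kernel (PyTorch >= 2.3).
# An assumption about the intended follow-up, not what this patch does.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)  # dummy (batch, heads, seq, head_dim)

# Pin SDPA to one backend, mirroring the old enable_flash/enable_math/
# enable_mem_efficient flags; pass a list to allow several backends.
with sdpa_kernel(SDPBackend.MATH):
	out = F.scaled_dot_product_attention(q, k, v)
```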