diff --git a/vall_e/config.py b/vall_e/config.py index 722a272..05dec9e 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -574,6 +574,7 @@ class DeepSpeed: max_loss_scale: float = 1048576.0 loss_scale = 0.0 + profile: bool = False config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config @cached_property @@ -592,6 +593,12 @@ class DeepSpeed: autotune_params = cfg.hyperparameters.autotune_params + profiler_path = str( cfg.rel_path / "profiler.log" ) + + ds_cfg_path = cfg.rel_path / "ds_config.json" + if not ds_cfg_path.exists(): + ds_cfg_path = Path("./data/ds_config.json") + if "enabled" not in autotune_params: autotune_params['enabled'] = True @@ -710,6 +717,14 @@ class DeepSpeed: } if self.zero_optimization_level > 0 else None, "comms_logger": { "enabled": False + }, + "flops_profiler": { + "enabled": self.profile, + "profile_step": 1, + "module_depth": -1, + "top_modules": 1, + "detailed": True, + "output_file": profiler_path } } @@ -717,10 +732,9 @@ class DeepSpeed: for k in null_keys: del ds_cfg[k] - if os.path.exists("./data/ds_config.json"): - ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8")).read()) - else: - ds_cfg.update(self.config) + ds_cfg.update(self.config) + if ds_cfg_path.exists(): + ds_cfg.update( json_read( ds_cfg_path ) ) return ds_cfg diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 97915e7..6b34689 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -1222,7 +1222,7 @@ class Base_V2(nn.Module): # create special masks # to-do, create it if mixed (although I expect this model to be purely non-causal) if self.use_segmented_attention_mask and not any(is_causal): - aux_lens = torch.zeros((batch_size, 2), device=x.device, dtype=torch.int32) + aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2 # fill aux lens for batch_index, batch_input in enumerate( inputs ): for name, input in batch_input: