Handle the case of dropping cond for the segmented attention mask

This commit is contained in:
mrq 2025-03-07 14:11:58 -06:00
parent 89e52b9877
commit 8d848ed549
2 changed files with 19 additions and 5 deletions

View File

@ -574,6 +574,7 @@ class DeepSpeed:
max_loss_scale: float = 1048576.0
loss_scale = 0.0
profile: bool = False
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@cached_property
@ -592,6 +593,12 @@ class DeepSpeed:
autotune_params = cfg.hyperparameters.autotune_params
profiler_path = str( cfg.rel_path / "profiler.log" )
ds_cfg_path = cfg.rel_path / "ds_config.json"
if not ds_cfg_path.exists():
ds_cfg_path = Path("./data/ds_config.json")
if "enabled" not in autotune_params:
autotune_params['enabled'] = True
@ -710,6 +717,14 @@ class DeepSpeed:
} if self.zero_optimization_level > 0 else None,
"comms_logger": {
"enabled": False
},
"flops_profiler": {
"enabled": self.profile,
"profile_step": 1,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
"output_file": profiler_path
}
}
@ -717,10 +732,9 @@ class DeepSpeed:
for k in null_keys:
del ds_cfg[k]
if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8")).read())
else:
ds_cfg.update(self.config)
if ds_cfg_path.exists():
ds_cfg.update( json_read( ds_cfg_path ) )
return ds_cfg

View File

@ -1222,7 +1222,7 @@ class Base_V2(nn.Module):
# create special masks
# TODO: also create it for mixed causal/non-causal batches (although I expect this model to be purely non-causal)
if self.use_segmented_attention_mask and not any(is_causal):
aux_lens = torch.zeros((batch_size, 2), device=x.device, dtype=torch.int32)
aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2
# fill aux lens
for batch_index, batch_input in enumerate( inputs ):
for name, input in batch_input: