handle case of dropping cond for segment mask
commit 8d848ed549 (parent 89e52b9877)
@@ -574,6 +574,7 @@ class DeepSpeed:
 	max_loss_scale: float = 1048576.0
 	loss_scale = 0.0
 
+	profile: bool = False
 	config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
 	@cached_property
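The hunk above only declares the flag; as a reading aid, here is a minimal, hypothetical stand-in (not the repo's actual class) showing how a dataclass field like `profile` gates a dict built under `@cached_property`, which is the pattern the later hunks rely on:

    from dataclasses import dataclass, field
    from functools import cached_property

    @dataclass
    class DeepSpeedSketch:  # trimmed stand-in for illustration only
        profile: bool = False
        config: dict = field(default_factory=lambda: {})

        @cached_property
        def ds_cfg(self) -> dict:
            # the real property assembles a full DeepSpeed config; only the gating matters here
            return { "flops_profiler": { "enabled": self.profile } }

    assert DeepSpeedSketch().ds_cfg["flops_profiler"]["enabled"] is False
    assert DeepSpeedSketch(profile=True).ds_cfg["flops_profiler"]["enabled"] is True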
@@ -592,6 +593,12 @@ class DeepSpeed:
 		autotune_params = cfg.hyperparameters.autotune_params
 
+		profiler_path = str( cfg.rel_path / "profiler.log" )
+
+		ds_cfg_path = cfg.rel_path / "ds_config.json"
+		if not ds_cfg_path.exists():
+			ds_cfg_path = Path("./data/ds_config.json")
+
 		if "enabled" not in autotune_params:
 			autotune_params['enabled'] = True
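This hunk computes the profiler output path and resolves which `ds_config.json` to use: the run-local one under `cfg.rel_path` if it exists, otherwise the shared `./data/ds_config.json`. A hedged sketch of the same lookup as a standalone helper (the helper name is mine, not the repo's):

    from pathlib import Path

    def resolve_ds_config_path(rel_path: Path) -> Path:
        # prefer a run-local override, fall back to the shared default
        candidate = rel_path / "ds_config.json"
        return candidate if candidate.exists() else Path("./data/ds_config.json")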
@@ -710,6 +717,14 @@ class DeepSpeed:
 			} if self.zero_optimization_level > 0 else None,
 			"comms_logger": {
 				"enabled": False
+			},
+			"flops_profiler": {
+				"enabled": self.profile,
+				"profile_step": 1,
+				"module_depth": -1,
+				"top_modules": 1,
+				"detailed": True,
+				"output_file": profiler_path
 			}
 		}
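The keys added here (`enabled`, `profile_step`, `module_depth`, `top_modules`, `detailed`, `output_file`) are DeepSpeed's flops profiler options, gated on the new `profile` flag and pointed at the `profiler_path` computed earlier. A rough sketch of what this fragment of the generated config serializes to when `profile=True` (the path value is a placeholder):

    import json

    fragment = {
        "comms_logger": { "enabled": False },
        "flops_profiler": {
            "enabled": True,           # self.profile
            "profile_step": 1,         # profile the first training step
            "module_depth": -1,        # aggregate down to the innermost modules
            "top_modules": 1,          # only the top module in the summary
            "detailed": True,
            "output_file": "<rel_path>/profiler.log",  # profiler_path
        },
    }
    print(json.dumps(fragment, indent=2))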
@@ -717,10 +732,9 @@ class DeepSpeed:
 		for k in null_keys:
 			del ds_cfg[k]
 
-		if os.path.exists("./data/ds_config.json"):
-			ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8")).read())
-		else:
-			ds_cfg.update(self.config)
+		ds_cfg.update(self.config)
+		if ds_cfg_path.exists():
+			ds_cfg.update( json_read( ds_cfg_path ) )
 
 		return ds_cfg
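The rewrite drops the broken `json.loads(open(...)).read()` call (`.read()` belongs on the file handle, not on the parsed object) in favor of the repo's `json_read` helper, reuses `ds_cfg_path` from the earlier hunk, and applies `self.config` before layering the on-disk file on top, so in this ordering values from `ds_config.json` take precedence. A sketch of what `json_read` is assumed to do (the actual helper lives elsewhere in the repo):

    import json
    from pathlib import Path

    def json_read(path: Path) -> dict:
        # assumed behavior: parse a JSON file into a Python dict
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)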
@@ -1222,7 +1222,7 @@ class Base_V2(nn.Module):
 		# create special masks
 		# to-do, create it if mixed (although I expect this model to be purely non-causal)
 		if self.use_segmented_attention_mask and not any(is_causal):
-			aux_lens = torch.zeros((batch_size, 2), device=x.device, dtype=torch.int32)
+			aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2
 			# fill aux lens
 			for batch_index, batch_input in enumerate( inputs ):
 				for name, input in batch_input:
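Initializing `aux_lens` to 2 instead of 0 gives every batch entry a nonzero default length for its auxiliary segments, which is presumably what handles the "dropping cond" case in the commit title: if the conditioning input is absent from `batch_input`, its segment length no longer defaults to zero. For readers unfamiliar with segmented attention, a generic illustration (not the repo's implementation) of turning per-segment lengths into a block-diagonal mask:

    import torch

    def segment_mask_from_lens(seg_lens: torch.Tensor) -> torch.Tensor:
        # seg_lens: (num_segments,) int tensor; positions attend only within their own segment
        seg_ids = torch.repeat_interleave(torch.arange(len(seg_lens)), seg_lens)
        return seg_ids[:, None] == seg_ids[None, :]  # (seq_len, seq_len) bool mask

    mask = segment_mask_from_lens(torch.tensor([2, 3]))
    assert mask.shape == (5, 5)
    assert mask[0, 1].item() is True and mask[0, 2].item() is False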