handle case of dropping cond for segment mask
This commit is contained in:
parent
89e52b9877
commit
8d848ed549
|
@ -574,6 +574,7 @@ class DeepSpeed:
|
|||
max_loss_scale: float = 1048576.0
|
||||
loss_scale = 0.0
|
||||
|
||||
profile: bool = False
|
||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
@cached_property
|
||||
|
@ -592,6 +593,12 @@ class DeepSpeed:
|
|||
|
||||
autotune_params = cfg.hyperparameters.autotune_params
|
||||
|
||||
profiler_path = str( cfg.rel_path / "profiler.log" )
|
||||
|
||||
ds_cfg_path = cfg.rel_path / "ds_config.json"
|
||||
if not ds_cfg_path.exists():
|
||||
ds_cfg_path = Path("./data/ds_config.json")
|
||||
|
||||
if "enabled" not in autotune_params:
|
||||
autotune_params['enabled'] = True
|
||||
|
||||
|
@ -710,6 +717,14 @@ class DeepSpeed:
|
|||
} if self.zero_optimization_level > 0 else None,
|
||||
"comms_logger": {
|
||||
"enabled": False
|
||||
},
|
||||
"flops_profiler": {
|
||||
"enabled": self.profile,
|
||||
"profile_step": 1,
|
||||
"module_depth": -1,
|
||||
"top_modules": 1,
|
||||
"detailed": True,
|
||||
"output_file": profiler_path
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -717,10 +732,9 @@ class DeepSpeed:
|
|||
for k in null_keys:
|
||||
del ds_cfg[k]
|
||||
|
||||
if os.path.exists("./data/ds_config.json"):
|
||||
ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8")).read())
|
||||
else:
|
||||
ds_cfg.update(self.config)
|
||||
ds_cfg.update(self.config)
|
||||
if ds_cfg_path.exists():
|
||||
ds_cfg.update( json_read( ds_cfg_path ) )
|
||||
|
||||
return ds_cfg
|
||||
|
||||
|
|
|
@ -1222,7 +1222,7 @@ class Base_V2(nn.Module):
|
|||
# create special masks
|
||||
# to-do, create it if mixed (although I expect this model to be purely non-causal)
|
||||
if self.use_segmented_attention_mask and not any(is_causal):
|
||||
aux_lens = torch.zeros((batch_size, 2), device=x.device, dtype=torch.int32)
|
||||
aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2
|
||||
# fill aux lens
|
||||
for batch_index, batch_input in enumerate( inputs ):
|
||||
for name, input in batch_input:
|
||||
|
|
Loading…
Reference in New Issue
Block a user