some possible sanity checks for the DeepSpeed config

This commit is contained in:
mrq 2024-05-09 22:48:42 -05:00
parent c4b696ebeb
commit b7bd885651

View File

@ -338,6 +338,7 @@ class DeepSpeed:
use_compression_training: bool = False use_compression_training: bool = False
compression_bits: int = 8 compression_bits: int = 8
inferencing: bool = False inferencing: bool = False
amp: bool = True
@cached_property @cached_property
def ds_cfg(self): def ds_cfg(self):
@ -353,10 +354,6 @@ class DeepSpeed:
if 'total_num_steps' not in scheduler_params: if 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations scheduler_params['total_num_steps'] = cfg.trainer.iterations
# documentation says neither can work
if cfg.trainer.weight_dtype.lower() == "float16":
cfg.trainer.amp = False
autotune_params = cfg.hyperparameters.autotune_params autotune_params = cfg.hyperparameters.autotune_params
if "enabled" not in autotune_params: if "enabled" not in autotune_params:
@ -368,6 +365,14 @@ class DeepSpeed:
if "exps_dir" not in autotune_params: if "exps_dir" not in autotune_params:
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" ) autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16":
self.amp = False
# disable local AMP
if self.amp:
cfg.trainer.amp = False
ds_cfg = { ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size, "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps, "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
@ -382,13 +387,13 @@ class DeepSpeed:
"gradient_clipping": cfg.hyperparameters.gradient_clipping, "gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": { "fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16", "enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": False, # ??? "auto_cast": True, # ???
}, } if not self.amp else None,
"bf16": { "bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16", "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
}, },
"amp": { "amp": {
"enabled": cfg.trainer.amp, "enabled": self.amp,
}, },
"autotuning": autotune_params if cfg.hyperparameters.autotune else None, "autotuning": autotune_params if cfg.hyperparameters.autotune else None,
"compression_training": { "compression_training": {
@ -469,9 +474,6 @@ class DeepSpeed:
} }
} }
# disable local AMP
cfg.trainer.amp = False
null_keys = [ k for k in ds_cfg if not ds_cfg[k] ] null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
for k in null_keys: for k in null_keys:
del ds_cfg[k] del ds_cfg[k]