Maybe it's better to be more explicit in DeepSpeed configs

This commit is contained in:
mrq 2024-05-11 13:57:43 -05:00
parent 4d93a16ef7
commit 04a80d6b55

View File

@@ -340,7 +340,11 @@ class DeepSpeed:
use_compression_training: bool = False
compression_bits: int = 8
inferencing: bool = False
amp: bool = False
fp16: bool = False
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@cached_property
def ds_cfg(self):
@@ -368,7 +372,7 @@ class DeepSpeed:
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16":
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
self.amp = False
# disable local AMP
@@ -388,9 +392,9 @@ class DeepSpeed:
} if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
"auto_cast": True, # ???
} if not self.amp else None,
},
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
},
@@ -482,6 +486,8 @@ class DeepSpeed:
if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
else:
ds_cfg.update(self.config)
return ds_cfg