maybe it's better to be more explicit in deepspeed configs

mrq 2024-05-11 13:57:43 -05:00
parent 4d93a16ef7
commit 04a80d6b55


@@ -340,7 +340,11 @@ class DeepSpeed:
     use_compression_training: bool = False
     compression_bits: int = 8
     inferencing: bool = False
     amp: bool = False
+    fp16: bool = False
+
+    config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
     @cached_property
     def ds_cfg(self):
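The two added fields make DeepSpeed's fp16 engine opt-in and provide a raw passthrough dict for the generated config. A minimal usage sketch, assuming the dataclass above; the instantiation is illustrative, though zero_allow_untested_optimizer is a standard top-level DeepSpeed option:

    # hypothetical usage: opt into DeepSpeed fp16 explicitly and forward
    # an arbitrary option verbatim into the generated ds_cfg
    ds = DeepSpeed(
        fp16=True,
        config={"zero_allow_untested_optimizer": True},
    )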
@@ -368,7 +372,7 @@ class DeepSpeed:
     autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )

     # DeepSpeed fp16 is incompatible with its AMP
-    if cfg.trainer.weight_dtype.lower() == "float16":
+    if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
         self.amp = False

     # disable local AMP
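The tightened guard means float16 weights alone no longer force local AMP off; it only happens when DeepSpeed fp16 is explicitly requested. A standalone sketch of the decision (the function and its parameters are hypothetical stand-ins mirroring the diff):

    # sketch: local AMP is only forced off when float16 weights are combined
    # with an explicit fp16=True, since DeepSpeed fp16 is incompatible with its AMP
    def resolve_amp(weight_dtype: str, fp16: bool, amp: bool) -> bool:
        if weight_dtype.lower() == "float16" and fp16:
            return False  # DeepSpeed fp16 wins; AMP must be off
        return amp  # otherwise leave AMP as configured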
@@ -388,9 +392,9 @@ class DeepSpeed:
         } if not cfg.hyperparameters.torch_scheduler else None,
         "gradient_clipping": cfg.hyperparameters.gradient_clipping,
         "fp16": {
-            "enabled": cfg.trainer.weight_dtype.lower() == "float16",
+            "enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
             "auto_cast": True, # ???
-        } if not self.amp else None,
+        },
         "bf16": {
             "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
         },
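Since the fp16 block is no longer dropped when AMP is enabled, both precision blocks are always emitted and gated only by the expressions above. Evaluating those expressions for float16 weights with fp16=True would yield:

    # fragment of ds_cfg for weight_dtype == "float16" and self.fp16 == True
    {
        "fp16": { "enabled": True, "auto_cast": True },
        "bf16": { "enabled": False },
    }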
@@ -482,6 +486,8 @@ class DeepSpeed:
     if os.path.exists("./data/ds_config.json"):
         ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
+    else:
+        ds_cfg.update(self.config)

     return ds_cfg
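One note on the merge semantics: an on-disk ./data/ds_config.json still takes precedence over the new config field, and dict.update() replaces top-level keys wholesale rather than deep-merging, so a passthrough block overrides a generated one entirely. A quick demonstration of the shallow merge:

    # dict.update() is shallow: a passthrough "fp16" block replaces the
    # generated one outright rather than merging key by key
    ds_cfg = { "fp16": { "enabled": True, "auto_cast": True } }
    ds_cfg.update({ "fp16": { "enabled": False } })
    assert ds_cfg["fp16"] == { "enabled": False }  # "auto_cast" is gone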