maybe it's better to be more explicit in deepspeed configs
This commit is contained in:
parent
4d93a16ef7
commit
04a80d6b55
|
@ -340,7 +340,11 @@ class DeepSpeed:
|
||||||
use_compression_training: bool = False
|
use_compression_training: bool = False
|
||||||
compression_bits: int = 8
|
compression_bits: int = 8
|
||||||
inferencing: bool = False
|
inferencing: bool = False
|
||||||
|
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
|
fp16: bool = False
|
||||||
|
|
||||||
|
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def ds_cfg(self):
|
def ds_cfg(self):
|
||||||
|
@ -368,7 +372,7 @@ class DeepSpeed:
|
||||||
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
||||||
|
|
||||||
# DeepSpeed fp16 is incompatible with its AMP
|
# DeepSpeed fp16 is incompatible with its AMP
|
||||||
if cfg.trainer.weight_dtype.lower() == "float16":
|
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
|
||||||
self.amp = False
|
self.amp = False
|
||||||
|
|
||||||
# disable local AMP
|
# disable local AMP
|
||||||
|
@ -388,9 +392,9 @@ class DeepSpeed:
|
||||||
} if not cfg.hyperparameters.torch_scheduler else None,
|
} if not cfg.hyperparameters.torch_scheduler else None,
|
||||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||||
"fp16": {
|
"fp16": {
|
||||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
|
||||||
"auto_cast": True, # ???
|
"auto_cast": True, # ???
|
||||||
} if not self.amp else None,
|
},
|
||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
||||||
},
|
},
|
||||||
|
@ -482,6 +486,8 @@ class DeepSpeed:
|
||||||
|
|
||||||
if os.path.exists("./data/ds_config.json"):
|
if os.path.exists("./data/ds_config.json"):
|
||||||
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
|
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
|
||||||
|
else:
|
||||||
|
ds_cfg.update(self.config)
|
||||||
|
|
||||||
return ds_cfg
|
return ds_cfg
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user