Be more explicit in DeepSpeed configs: enable fp16 only when the `fp16` flag is also set, and allow extra DeepSpeed settings to be passed through `config`
This commit is contained in:
parent
4d93a16ef7
commit
04a80d6b55
|
@ -340,7 +340,11 @@ class DeepSpeed:
|
|||
use_compression_training: bool = False
|
||||
compression_bits: int = 8
|
||||
inferencing: bool = False
|
||||
|
||||
amp: bool = False
|
||||
fp16: bool = False
|
||||
|
||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
@cached_property
|
||||
def ds_cfg(self):
|
||||
|
@ -368,7 +372,7 @@ class DeepSpeed:
|
|||
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
||||
|
||||
# DeepSpeed fp16 is incompatible with its AMP
|
||||
if cfg.trainer.weight_dtype.lower() == "float16":
|
||||
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
|
||||
self.amp = False
|
||||
|
||||
# disable local AMP
|
||||
|
@ -388,9 +392,9 @@ class DeepSpeed:
|
|||
} if not cfg.hyperparameters.torch_scheduler else None,
|
||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||
"fp16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
|
||||
"auto_cast": True, # ???
|
||||
} if not self.amp else None,
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
||||
},
|
||||
|
@ -482,6 +486,8 @@ class DeepSpeed:
|
|||
|
||||
if os.path.exists("./data/ds_config.json"):
|
||||
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
|
||||
else:
|
||||
ds_cfg.update(self.config)
|
||||
|
||||
return ds_cfg
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user