diff --git a/vall_e/config.py b/vall_e/config.py
index ef6d515..8ae3309 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -340,7 +340,11 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 	inferencing: bool = False
+	amp: bool = False
+	fp16: bool = False
+
+	config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
 	@cached_property
 	def ds_cfg(self):
@@ -368,7 +372,7 @@ class DeepSpeed:
 		autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
 
 		# DeepSpeed fp16 is incompatible with its AMP
-		if cfg.trainer.weight_dtype.lower() == "float16":
+		if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
 			self.amp = False
 
 		# disable local AMP
@@ -388,9 +392,9 @@ class DeepSpeed:
 			} if not cfg.hyperparameters.torch_scheduler else None,
 			"gradient_clipping": cfg.hyperparameters.gradient_clipping,
 			"fp16": {
-				"enabled": cfg.trainer.weight_dtype.lower() == "float16",
+				"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
 				"auto_cast": True, # ???
-			} if not self.amp else None,
+			},
 			"bf16": {
 				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
 			},
@@ -482,6 +486,8 @@ class DeepSpeed:
 
 		if os.path.exists("./data/ds_config.json"):
 			ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
+		else:
+			ds_cfg.update(self.config)
 
 		return ds_cfg
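
To make the behavior change concrete, here is a minimal runnable sketch, not the real module: `DeepSpeedStub` and `build_ds_cfg` are hypothetical stand-ins for the `DeepSpeed` dataclass and its `ds_cfg` property, with the trainer/hyperparameter wiring stripped out. It shows the two effects of the patch: the fp16 section is enabled only when the weight dtype is float16 *and* the new `fp16` flag is set, and the new `config` dict is merged into the generated DeepSpeed config whenever no `./data/ds_config.json` override exists.

    from dataclasses import dataclass, field
    import json
    import os

    @dataclass
    class DeepSpeedStub:  # hypothetical stand-in for the DeepSpeed dataclass above
        amp: bool = False
        fp16: bool = False
        config: dict = field(default_factory=lambda: {})  # passed through to the DeepSpeed config

        def build_ds_cfg(self, weight_dtype: str = "float16") -> dict:
            # fp16 mixed precision now requires the explicit opt-in flag, not just the dtype.
            ds_cfg = {
                "fp16": {
                    "enabled": weight_dtype.lower() == "float16" and self.fp16,
                    "auto_cast": True,
                },
                "bf16": {
                    "enabled": weight_dtype.lower() == "bfloat16",
                },
            }
            # An on-disk override still wins; otherwise the inline dict is merged in.
            if os.path.exists("./data/ds_config.json"):
                with open("./data/ds_config.json", "r", encoding="utf-8") as f:
                    ds_cfg.update(json.load(f))
            else:
                ds_cfg.update(self.config)
            return ds_cfg

    # fp16 stays off despite a float16 weight dtype until fp16=True is set,
    # and inline config keys flow through when no ./data/ds_config.json exists:
    print(DeepSpeedStub().build_ds_cfg()["fp16"]["enabled"])   # False
    ds = DeepSpeedStub(fp16=True, config={"zero_optimization": {"stage": 2}})
    cfg = ds.build_ds_cfg()
    print(cfg["fp16"]["enabled"], cfg["zero_optimization"])    # True {'stage': 2}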