diff --git a/vall_e/config.py b/vall_e/config.py
index 100fb0a..ab57e7f 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -338,6 +338,7 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 	inferencing: bool = False
+	amp: bool = True
 
 	@cached_property
 	def ds_cfg(self):
@@ -353,10 +354,6 @@ class DeepSpeed:
 		if 'total_num_steps' not in scheduler_params:
 			scheduler_params['total_num_steps'] = cfg.trainer.iterations
 
-		# documentation says neither can work
-		if cfg.trainer.weight_dtype.lower() == "float16":
-			cfg.trainer.amp = False
-
 		autotune_params = cfg.hyperparameters.autotune_params
 
 		if "enabled" not in autotune_params:
@@ -368,6 +365,14 @@ class DeepSpeed:
 		if "exps_dir" not in autotune_params:
 			autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
 
+		# DeepSpeed fp16 is incompatible with its AMP
+		if cfg.trainer.weight_dtype.lower() == "float16":
+			self.amp = False
+
+		# disable local AMP
+		if self.amp:
+			cfg.trainer.amp = False
+
 		ds_cfg = {
 			"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
 			"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
@@ -382,13 +387,13 @@ class DeepSpeed:
 			"gradient_clipping": cfg.hyperparameters.gradient_clipping,
 			"fp16": {
 				"enabled": cfg.trainer.weight_dtype.lower() == "float16",
-				"auto_cast": False, # ???
-			},
+				"auto_cast": True, # ???
+			} if not self.amp else None,
 			"bf16": {
 				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
 			},
 			"amp": {
-				"enabled": cfg.trainer.amp,
+				"enabled": self.amp,
 			},
 			"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
 			"compression_training": {
@@ -469,9 +474,6 @@ class DeepSpeed:
 			}
 		}
 
-		# disable local AMP
-		cfg.trainer.amp = False
-
 		null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
 		for k in null_keys:
 			del ds_cfg[k]
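For reference, a minimal standalone sketch of the precision precedence this patch settles on, assuming only the fields shown above (weight_dtype and the new amp flag); resolve_precision is a hypothetical helper used purely for illustration and is not part of the patch:

# Hypothetical sketch of the precedence introduced above: fp16 weights force
# DeepSpeed's own fp16 engine (its AMP is turned off), bfloat16 maps to the
# bf16 section, and any other dtype keeps DeepSpeed AMP enabled while the
# trainer-side ("local") AMP is disabled so the two never run at once.
def resolve_precision(weight_dtype: str, amp: bool = True) -> dict:
	use_fp16 = weight_dtype.lower() == "float16"
	if use_fp16:
		amp = False  # DeepSpeed fp16 is incompatible with its AMP
	cfg = {
		"bf16": {"enabled": weight_dtype.lower() == "bfloat16"},
		"amp": {"enabled": amp},
	}
	# the fp16 section is only emitted when DeepSpeed AMP is off,
	# mirroring the `if not self.amp else None` expression in the patch
	if not amp:
		cfg["fp16"] = {"enabled": use_fp16, "auto_cast": True}
	return cfg

# e.g. resolve_precision("float16")  -> fp16 enabled, DeepSpeed AMP disabled
#      resolve_precision("float32")  -> DeepSpeed AMP enabled, fp16 section omitted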