diff --git a/vall_e/config.py b/vall_e/config.py
index 33d07ac..100fb0a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -313,6 +313,9 @@ class Hyperparameters:
 	scheduler: str = ""
 	scheduler_type: str = "" # deprecated
 	scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
+	autotune: bool = False
+	autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
 	torch_optimizer: bool = False
 	torch_scheduler: bool = False
@@ -354,6 +357,17 @@ class DeepSpeed:
 		if cfg.trainer.weight_dtype.lower() == "float16":
 			cfg.trainer.amp = False
 
+		autotune_params = cfg.hyperparameters.autotune_params
+
+		if "enabled" not in autotune_params:
+			autotune_params['enabled'] = True
+
+		if "results_dir" not in autotune_params:
+			autotune_params['results_dir'] = str( cfg.relpath / "autotune" / "results" )
+
+		if "exps_dir" not in autotune_params:
+			autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
+
 		ds_cfg = {
 			"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
 			"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
@@ -374,8 +388,9 @@ class DeepSpeed:
 				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
 			},
 			"amp": {
-				"enabled": cfg.trainer.amp,
-			},
+				"enabled": cfg.trainer.amp,
+			},
+			"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
 			"compression_training": {
 				"weight_quantization": {
 					"shared_parameters":{
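
For reference, a minimal sketch (not part of the patch) of what the defaulting logic above produces when hyperparameters.autotune is set to True and autotune_params is left empty. The relpath value Path("./training") is a hypothetical example, not something the patch defines:

from pathlib import Path

# Sketch only: mirrors the patch's defaulting of autotune_params,
# assuming a hypothetical cfg.relpath of Path("./training").
relpath = Path("./training")
autotune_params = {}  # user left hyperparameters.autotune_params empty

autotune_params.setdefault("enabled", True)
autotune_params.setdefault("results_dir", str(relpath / "autotune" / "results"))
autotune_params.setdefault("exps_dir", str(relpath / "autotune" / "exps_"))

# With autotune enabled, ds_cfg["autotuning"] receives this dict:
# {'enabled': True,
#  'results_dir': 'training/autotune/results',
#  'exps_dir': 'training/autotune/exps_'}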