autotune?

This commit is contained in:
mrq 2024-05-09 21:25:40 -05:00
parent 6ed6ab8c03
commit b6131565ad

View File

@ -313,6 +313,9 @@ class Hyperparameters:
scheduler: str = ""
scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
autotune: bool = False
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
torch_optimizer: bool = False
torch_scheduler: bool = False
@ -354,6 +357,17 @@ class DeepSpeed:
if cfg.trainer.weight_dtype.lower() == "float16":
cfg.trainer.amp = False
autotune_params = cfg.hyperparameters.autotune_params
if "enabled" not in autotune_params:
autotune_params['enabled'] = True
if "results_dir" not in autotune_params:
autotune_params['results_dir'] = str( cfg.relpath / "autotune" / "results" )
if "exps_dir" not in autotune_params:
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
@ -374,8 +388,9 @@ class DeepSpeed:
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
},
"amp": {
"enabled": cfg.trainer.amp,
},
"enabled": cfg.trainer.amp,
},
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
"compression_training": {
"weight_quantization": {
"shared_parameters":{