autotune?
This commit is contained in:
parent
6ed6ab8c03
commit
b6131565ad
|
@ -313,6 +313,9 @@ class Hyperparameters:
|
||||||
scheduler: str = ""
|
scheduler: str = ""
|
||||||
scheduler_type: str = "" # deprecated
|
scheduler_type: str = "" # deprecated
|
||||||
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||||
|
|
||||||
|
autotune: bool = False
|
||||||
|
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||||
|
|
||||||
torch_optimizer: bool = False
|
torch_optimizer: bool = False
|
||||||
torch_scheduler: bool = False
|
torch_scheduler: bool = False
|
||||||
|
@ -354,6 +357,17 @@ class DeepSpeed:
|
||||||
if cfg.trainer.weight_dtype.lower() == "float16":
|
if cfg.trainer.weight_dtype.lower() == "float16":
|
||||||
cfg.trainer.amp = False
|
cfg.trainer.amp = False
|
||||||
|
|
||||||
|
autotune_params = cfg.hyperparameters.autotune_params
|
||||||
|
|
||||||
|
if "enabled" not in autotune_params:
|
||||||
|
autotune_params['enabled'] = True
|
||||||
|
|
||||||
|
if "results_dir" not in autotune_params:
|
||||||
|
autotune_params['results_dir'] = str( cfg.relpath / "autotune" / "results" )
|
||||||
|
|
||||||
|
if "exps_dir" not in autotune_params:
|
||||||
|
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
||||||
|
|
||||||
ds_cfg = {
|
ds_cfg = {
|
||||||
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
|
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
|
||||||
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
|
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
|
||||||
|
@ -374,8 +388,9 @@ class DeepSpeed:
|
||||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
|
||||||
},
|
},
|
||||||
"amp": {
|
"amp": {
|
||||||
"enabled": cfg.trainer.amp,
|
"enabled": cfg.trainer.amp,
|
||||||
},
|
},
|
||||||
|
"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
|
||||||
"compression_training": {
|
"compression_training": {
|
||||||
"weight_quantization": {
|
"weight_quantization": {
|
||||||
"shared_parameters":{
|
"shared_parameters":{
|
||||||
|
|
Loading…
Reference in New Issue
Block a user