a bit more cleanup for deepspeed ds_cfg creation
commit 6ed6ab8c03
parent 0d5d545a40
@@ -305,14 +305,16 @@ class Hyperparameters:
 	gradient_clipping: int | float = 100
 
 	optimizer: str = "Adamw"
-	torch_optimizer: bool = False
+	optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
-	optimizer_params: dict = field(default_factory=lambda: {})
 	learning_rate: float = 3.25e-4
+	warmup_steps: int = 0
+
 	scheduler: str = ""
 	scheduler_type: str = "" # deprecated
-	scheduler_params: dict = field(default_factory=lambda: {})
+	scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
+	torch_optimizer: bool = False
 	torch_scheduler: bool = False
 
 @dataclass()
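Note: the two pass-through dicts above feed directly into the DeepSpeed config assembled in ds_cfg below. A purely illustrative way to populate them (field names come from the dataclass above; the concrete values and the WarmupDecayLR choice are made up for the example):

hparams = Hyperparameters(
	optimizer = "Adamw",
	optimizer_params = { "betas": [ 0.9, 0.96 ], "weight_decay": 0.01 }, # forwarded to the DeepSpeed optimizer entry
	learning_rate = 3.25e-4,
	warmup_steps = 500,
	scheduler = "WarmupDecayLR",
	scheduler_params = { "warmup_min_lr": 0.0 }, # forwarded to the DeepSpeed scheduler entry
)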
@@ -336,21 +338,28 @@ class DeepSpeed:
 
 	@cached_property
 	def ds_cfg(self):
-		scheduler_params = {}
-		for k in cfg.hyperparameters.scheduler_params:
-			scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
+		optimizer_params = cfg.hyperparameters.optimizer_params
+		if 'lr' not in optimizer_params:
+			optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
 
-		if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+		scheduler_params = cfg.hyperparameters.scheduler_params
+		if 'warmup_num_steps' not in scheduler_params:
+			scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
+
+		if 'total_num_steps' not in scheduler_params:
 			scheduler_params['total_num_steps'] = cfg.trainer.iterations
 
+		# documentation says neither can work
+		if cfg.trainer.weight_dtype.lower() == "float16":
+			cfg.trainer.amp = False
+
+
 		ds_cfg = {
 			"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
 			"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
 			"optimizer": {
 				"type": cfg.hyperparameters.optimizer,
-				"params": {
-					"lr": cfg.hyperparameters.learning_rate,
-				}
+				"params": optimizer_params,
 			} if not cfg.hyperparameters.torch_optimizer else None,
 			"scheduler": {
 				"type": cfg.hyperparameters.scheduler,
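For reference, with neither torch_optimizer nor torch_scheduler set, the hunk above produces an optimizer/scheduler section shaped roughly like the sketch below; the values are illustrative, and the scheduler "params" entry is presumably filled from scheduler_params on the unchanged line just past this hunk:

ds_cfg = {
	"train_micro_batch_size_per_gpu": 8,
	"gradient_accumulation_steps": 32,
	"optimizer": {
		"type": "Adamw",
		"params": { "lr": 3.25e-4 },	# optimizer_params, with lr defaulted from learning_rate
	},
	"scheduler": {
		"type": "WarmupDecayLR",
		"params": {
			"warmup_num_steps": 500,	# defaulted from warmup_steps
			"total_num_steps": 10000,	# defaulted from cfg.trainer.iterations
		},
	},
	# ...
}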
@@ -358,12 +367,15 @@ class DeepSpeed:
 			} if not cfg.hyperparameters.torch_scheduler else None,
 			"gradient_clipping": cfg.hyperparameters.gradient_clipping,
 			"fp16": {
-				"enabled": True,
-				"auto_cast": True,
-			} if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
-			"bf16": {
-				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
+				"enabled": cfg.trainer.weight_dtype.lower() == "float16",
+				"auto_cast": False, # ???
 			},
+			"bf16": {
+				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
+			},
+			"amp": {
+				"enabled": cfg.trainer.amp,
+			},
 			"compression_training": {
 				"weight_quantization": {
 					"shared_parameters":{
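A quick sketch of how the three precision blocks evaluate under this change, with hypothetical trainer settings (this just mirrors the expressions above):

weight_dtype, amp = "bfloat16", False	# stand-ins for cfg.trainer.weight_dtype / cfg.trainer.amp
precision = {
	"fp16": { "enabled": weight_dtype.lower() == "float16", "auto_cast": False },
	"bf16": { "enabled": weight_dtype.lower() == "bfloat16" },
	"amp": { "enabled": amp },
}
# -> fp16 disabled, bf16 enabled, amp disabled; with weight_dtype "float16" the earlier
#    hunk forces cfg.trainer.amp off first, per the "documentation says neither can work" note.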
@@ -387,11 +399,7 @@ class DeepSpeed:
 							"target_bits": self.compression_bits,
 							"quantization_period": 0
 						},
-						"modules": [
-							# "^.+?$"
-							"blocks", # for transformer-based models
-							"retnet", # for RetNets-based models
-						]
+						"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
 					}
 				}
 			},
@@ -415,11 +423,7 @@ class DeepSpeed:
 						"params": {
 							"bits": self.compression_bits,
 						},
-						"modules": [
-							# "^.+?$"
-							"blocks", # for transformer-based models
-							"retnet", # for RetNets-based models
-						]
+						"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
 					}
 				}
 			},
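Both quantization groups now scope compression by module name, targeting submodules matching "self_attn" and "mlp" (the LLaMA naming) instead of the old "blocks"/"retnet" patterns. A small sketch for checking which names such a filter would catch on a given model (the helper name and keyword defaults are illustrative, not part of the repo):

import torch.nn as nn

def matched_modules(model: nn.Module, keywords = ( "self_attn", "mlp" )):
	# names of submodules a name-based compression group would apply to
	return [ name for name, _ in model.named_modules() if any( k in name for k in keywords ) ]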
@@ -450,6 +454,9 @@ class DeepSpeed:
 			}
 		}
 
+		# disable local AMP
+		cfg.trainer.amp = False
+
 		null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
 		for k in null_keys:
 			del ds_cfg[k]
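With that, any top-level entry left falsy, e.g. "optimizer" or "scheduler" set to None because the corresponding torch_* flag is enabled, is pruned before the config is handed off. A minimal usage sketch, assuming a DeepSpeed version whose initialize() accepts a config dict directly, and that ds is an instance of the DeepSpeed dataclass above (model is any torch.nn.Module prepared elsewhere):

import deepspeed

model_engine, optimizer, _, scheduler = deepspeed.initialize(
	model = model,
	model_parameters = model.parameters(),
	config = ds.ds_cfg,	# the pruned config built above
)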