a bit more cleanup for deepspeed ds_cfg creation

This commit is contained in:
mrq 2024-05-09 21:00:26 -05:00
parent 0d5d545a40
commit 6ed6ab8c03

View File

@ -305,14 +305,16 @@ class Hyperparameters:
gradient_clipping: int | float = 100 gradient_clipping: int | float = 100
optimizer: str = "Adamw" optimizer: str = "Adamw"
torch_optimizer: bool = False optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
optimizer_params: dict = field(default_factory=lambda: {})
learning_rate: float = 3.25e-4 learning_rate: float = 3.25e-4
warmup_steps: int = 0
scheduler: str = "" scheduler: str = ""
scheduler_type: str = "" # deprecated scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {}) scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
torch_optimizer: bool = False
torch_scheduler: bool = False torch_scheduler: bool = False
@dataclass() @dataclass()
@ -336,21 +338,28 @@ class DeepSpeed:
@cached_property @cached_property
def ds_cfg(self): def ds_cfg(self):
scheduler_params = {} optimizer_params = cfg.hyperparameters.optimizer_params
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k] if 'lr' not in optimizer_params:
optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params: scheduler_params = cfg.hyperparameters.scheduler_params
if 'warmup_num_steps' not in scheduler_params:
scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
if 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations scheduler_params['total_num_steps'] = cfg.trainer.iterations
# documentation says neither can work
if cfg.trainer.weight_dtype.lower() == "float16":
cfg.trainer.amp = False
ds_cfg = { ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size, "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps, "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
"optimizer": { "optimizer": {
"type": cfg.hyperparameters.optimizer, "type": cfg.hyperparameters.optimizer,
"params": { "params": optimizer_params,
"lr": cfg.hyperparameters.learning_rate,
}
} if not cfg.hyperparameters.torch_optimizer else None, } if not cfg.hyperparameters.torch_optimizer else None,
"scheduler": { "scheduler": {
"type": cfg.hyperparameters.scheduler, "type": cfg.hyperparameters.scheduler,
@ -358,12 +367,15 @@ class DeepSpeed:
} if not cfg.hyperparameters.torch_scheduler else None, } if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping, "gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": { "fp16": {
"enabled": True, "enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, "auto_cast": False, # ???
} if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
}, },
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
},
"amp": {
"enabled": cfg.trainer.amp,
},
"compression_training": { "compression_training": {
"weight_quantization": { "weight_quantization": {
"shared_parameters":{ "shared_parameters":{
@ -387,11 +399,7 @@ class DeepSpeed:
"target_bits": self.compression_bits, "target_bits": self.compression_bits,
"quantization_period": 0 "quantization_period": 0
}, },
"modules": [ "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
# "^.+?$"
"blocks", # for transformer-based models
"retnet", # for RetNets-based models
]
} }
} }
}, },
@ -415,11 +423,7 @@ class DeepSpeed:
"params": { "params": {
"bits": self.compression_bits, "bits": self.compression_bits,
}, },
"modules": [ "modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
# "^.+?$"
"blocks", # for transformer-based models
"retnet", # for RetNets-based models
]
} }
} }
}, },
@ -450,6 +454,9 @@ class DeepSpeed:
} }
} }
# disable local AMP
cfg.trainer.amp = False
null_keys = [ k for k in ds_cfg if not ds_cfg[k] ] null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
for k in null_keys: for k in null_keys:
del ds_cfg[k] del ds_cfg[k]