diff --git a/vall_e/config.py b/vall_e/config.py
index 810d1d0..33d07ac 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -305,14 +305,16 @@ class Hyperparameters:
 	gradient_clipping: int | float = 100
 
 	optimizer: str = "Adamw"
-	torch_optimizer: bool = False
-
-	optimizer_params: dict = field(default_factory=lambda: {})
+	optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
 	learning_rate: float = 3.25e-4
+	warmup_steps: int = 0
 
 	scheduler: str = ""
 	scheduler_type: str = "" # deprecated
-	scheduler_params: dict = field(default_factory=lambda: {})
+	scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
+
+	torch_optimizer: bool = False
 	torch_scheduler: bool = False
 
 @dataclass()
@@ -336,21 +338,28 @@ class DeepSpeed:
 
 	@cached_property
 	def ds_cfg(self):
-		scheduler_params = {}
-		for k in cfg.hyperparameters.scheduler_params:
-			scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
+		optimizer_params = cfg.hyperparameters.optimizer_params
+
+		if 'lr' not in optimizer_params:
+			optimizer_params["lr"] = cfg.hyperparameters.learning_rate,
 
-		if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+		scheduler_params = cfg.hyperparameters.scheduler_params
+		if 'warmup_num_steps' not in scheduler_params:
+			scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
+
+		if 'total_num_steps' not in scheduler_params:
 			scheduler_params['total_num_steps'] = cfg.trainer.iterations
 
+		# documentation says neither can work
+		if cfg.trainer.weight_dtype.lower() == "float16":
+			cfg.trainer.amp = False
+
 		ds_cfg = {
 			"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
 			"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
 			"optimizer": {
 				"type": cfg.hyperparameters.optimizer,
-				"params": {
-					"lr": cfg.hyperparameters.learning_rate,
-				}
+				"params": optimizer_params,
 			} if not cfg.hyperparameters.torch_optimizer else None,
 			"scheduler": {
 				"type": cfg.hyperparameters.scheduler,
@@ -358,12 +367,15 @@ class DeepSpeed:
 			} if not cfg.hyperparameters.torch_scheduler else None,
 			"gradient_clipping": cfg.hyperparameters.gradient_clipping,
 			"fp16": {
-				"enabled": True,
-				"auto_cast": True,
-			} if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
-			"bf16": {
-				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
+				"enabled": cfg.trainer.weight_dtype.lower() == "float16",
+				"auto_cast": False, # ???
 			},
+			"bf16": {
+				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
+			},
+			"amp": {
+				"enabled": cfg.trainer.amp,
+			},
 			"compression_training": {
 				"weight_quantization": {
 					"shared_parameters":{
@@ -387,11 +399,7 @@ class DeepSpeed:
 								"target_bits": self.compression_bits,
 								"quantization_period": 0
 							},
-							"modules": [
-								# "^.+?$"
-								"blocks", # for transformer-based models
-								"retnet", # for RetNets-based models
-							]
+							"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
 						}
 					}
 				},
@@ -415,11 +423,7 @@ class DeepSpeed:
 							"params": {
 								"bits": self.compression_bits,
 							},
-							"modules": [
-								# "^.+?$"
-								"blocks", # for transformer-based models
-								"retnet", # for RetNets-based models
-							]
+							"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
 						}
 					}
 				},
@@ -450,6 +454,9 @@ class DeepSpeed:
 			}
 		}
 
+		# disable local AMP
+		cfg.trainer.amp = False
+
 		null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
 		for k in null_keys:
 			del ds_cfg[k]
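
For reference, below is a minimal standalone sketch (not part of the patch) of how the new pass-through fields are meant to combine into the DeepSpeed optimizer/scheduler entries. The hparams dict and trainer_iterations value are hypothetical stand-ins for cfg.hyperparameters and cfg.trainer.iterations; only the fill-in logic mirrors the ds_cfg property above.

# Hypothetical stand-ins for cfg.hyperparameters and cfg.trainer.iterations;
# only the fill-in logic below mirrors the ds_cfg property in the patch.
hparams = {
	"optimizer": "Adamw",
	"optimizer_params": {},   # left empty -> "lr" gets filled from learning_rate
	"learning_rate": 3.25e-4,
	"warmup_steps": 100,
	"scheduler": "WarmupDecayLR",
	"scheduler_params": {},   # left empty -> warmup/total step counts get filled in
}
trainer_iterations = 1_000_000

optimizer_params = dict(hparams["optimizer_params"])
if "lr" not in optimizer_params:
	optimizer_params["lr"] = hparams["learning_rate"]

scheduler_params = dict(hparams["scheduler_params"])
if "warmup_num_steps" not in scheduler_params:
	scheduler_params["warmup_num_steps"] = hparams["warmup_steps"]
if "total_num_steps" not in scheduler_params:
	scheduler_params["total_num_steps"] = trainer_iterations

ds_cfg = {
	"optimizer": { "type": hparams["optimizer"], "params": optimizer_params },
	"scheduler": { "type": hparams["scheduler"], "params": scheduler_params },
}
# ds_cfg["optimizer"]["params"] -> {'lr': 0.000325}
# ds_cfg["scheduler"]["params"] -> {'warmup_num_steps': 100, 'total_num_steps': 1000000}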