From 34dcb845b55a2cdab6b7bd28f608f98b2fdb7f36 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 8 Mar 2023 15:31:33 +0000
Subject: [PATCH] actually make using adamw_zero optimizer for multi-gpus work

---
 models/.template.yaml |  4 +---
 src/train.py          |  7 +------
 src/utils.py          | 34 ++++++++++++++++++++++++++++++++--
 3 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/models/.template.yaml b/models/.template.yaml
index 258ee89..15cfbe3 100755
--- a/models/.template.yaml
+++ b/models/.template.yaml
@@ -126,9 +126,7 @@ train:
 
   ema_enabled: false # I really don't think EMA matters
 
-  default_lr_scheme: MultiStepLR
-  gen_lr_steps: ${gen_lr_steps} #[50000, 100000, 140000, 180000]
-  lr_gamma: 0.5
+  ${learning_rate_scheme}
 
 eval:
   pure: ${validation_enabled}
diff --git a/src/train.py b/src/train.py
index 79bdcde..144ecc0 100755
--- a/src/train.py
+++ b/src/train.py
@@ -20,15 +20,10 @@ if __name__ == "__main__":
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     args.opt = " ".join(args.opt) # absolutely disgusting
-
-
+
     with open(args.opt, 'r') as file:
         opt_config = yaml.safe_load(file)
 
-    if "WORLD_SIZE" in os.environ:
-        if int(os.environ["WORLD_SIZE"]) > 1 and opt_config["steps"]["gpt_train"]["optimizer"] == "adamw":
-            opt_config["steps"]["gpt_train"]["optimizer"] = "adamw_zero"
-
     if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]:
         os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
         os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
diff --git a/src/utils.py b/src/utils.py
index ce318a5..ba0d691 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -1008,6 +1008,21 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0,
     # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
     torch.multiprocessing.freeze_support()
 
+    # edit any gpu-count-specific variables
+    with open(config_path, 'r', encoding="utf-8") as f:
+        yaml_string = f.read()
+    edited = False
+    if gpus > 1:
+        yaml_string = yaml_string.replace(" adamw ", " adamw_zero ")
+        edited = True
+    else:
+        yaml_string = yaml_string.replace(" adamw_zero ", " adamw ")
+        edited = True
+    if edited:
+        print(f'Modified YAML config')
+        with open(config_path, 'w', encoding="utf-8") as f:
+            f.write(yaml_string)
+
     unload_tts()
     unload_whisper()
     unload_voicefixer()
@@ -1347,7 +1362,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
         messages
     )
 
-def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
+def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_scheme=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
     if not source_model:
         source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
 
@@ -1355,7 +1370,6 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
         "iterations": iterations if iterations else 500,
         "batch_size": batch_size if batch_size else 64,
         "learning_rate": learning_rate if learning_rate else 1e-5,
-        "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE,
         "gradient_accumulation_size": gradient_accumulation_size if gradient_accumulation_size else 4,
         "print_rate": print_rate if print_rate else 1,
         "save_rate": save_rate if save_rate else 50,
@@ -1379,6 +1393,22 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
         'workers': workers if workers else 2,
     }
 
+    LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"]
+    if learning_rate_scheme not in LEARNING_RATE_SCHEMES:
+        learning_rate_scheme = LEARNING_RATE_SCHEMES[0]
+
+    learning_rate_schema = [f"default_lr_scheme: {learning_rate_scheme}"]
+    if learning_rate_scheme == "MultiStepLR":
+        learning_rate_schema.append(f"  gen_lr_steps: {learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE}")
+        learning_rate_schema.append(f"  lr_gamma: 0.5")
+    elif learning_rate_scheme == "CosineAnnealingLR_Restart":
+        learning_rate_schema.append(f"  T_period: [120000, 120000, 120000]")
+        learning_rate_schema.append(f"  warmup: 10000")
+        learning_rate_schema.append(f"  eta_min: .01")
+        learning_rate_schema.append(f"  restarts: [140000, 280000]")
+        learning_rate_schema.append(f"  restart_weights: [.5, .25]")
+    settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
+
     if resume_path:
         settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
     else: