diff --git a/src/train.py b/src/train.py index a3e480e..b1f61a3 100755 --- a/src/train.py +++ b/src/train.py @@ -1,7 +1,7 @@ import os import sys import argparse - +import yaml """ if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ: @@ -14,6 +14,21 @@ if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ: os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1' """ +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') + args = parser.parse_args() + args.opt = " ".join(args.opt) # absolutely disgusting + + with open(args.opt, 'r') as file: + opt_config = yaml.safe_load(file) + + if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]: + os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' + os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0' + os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0' + os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0' # this is some massive kludge that only works if it's called from a shell and not an import/PIP package # it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell @@ -70,13 +85,9 @@ if __name__ == "__main__": import torch_intermediary if torch_intermediary.OVERRIDE_ADAM: print("Using BitsAndBytes ADAMW optimizations") + else: + print("NOT using BitsAndBytes ADAMW optimizations") except Exception as e: pass - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh - parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') - args = parser.parse_args() - args.opt = " ".join(args.opt) # absolutely disgusting - train(args.opt, args.launcher) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 1dadb87..a90a7b0 100755 --- a/src/utils.py +++ b/src/utils.py @@ -538,7 +538,7 @@ class TrainingState(): """ # I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly # will fix later - + #self.eta = (self.its - self.it) * self.it_time_delta self.it_time_deltas = self.it_time_deltas + self.it_time_delta self.it_taken = self.it_taken + 1 @@ -731,7 +731,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] def schedule_learning_rate( iterations ): return [int(iterations * d) for d in EPOCH_SCHEDULE] -def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): +def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -777,6 +777,9 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !") if not os.path.exists(get_halfp_model_path()): convert_to_halfp() + + if bnb: + messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !") messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)") @@ -791,7 +794,7 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b messages ) -def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ): +def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None ): settings = { "iterations": iterations if iterations else 500, "batch_size": batch_size if batch_size else 64, @@ -809,7 +812,8 @@ def save_training_settings( iterations=None, learning_rate=None, learning_rate_s 'resume_state': f"resume_state: '{resume_path}'", 'pretrain_model_gpt': f"pretrain_model_gpt: './models/tortoise/autoregressive{'_half' if half_p else ''}.pth'", - 'float16': 'true' if half_p else 'false' + 'float16': 'true' if half_p else 'false', + 'bitsandbytes': 'true' if bnb else 'false', } if resume_path: @@ -1038,6 +1042,9 @@ def setup_args(): 'concurrency-count': 2, 'output-sample-rate': 44100, 'output-volume': 1, + + 'training-default-halfp': False, + 'training-default-bnb': True, } if os.path.isfile('./config/exec.json'): @@ -1067,6 +1074,9 @@ def setup_args(): parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)") parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output") + parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") + parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") + parser.add_argument("--os", default="unix", help="Specifies which OS, easily") args = parser.parse_args() @@ -1093,7 +1103,7 @@ def setup_args(): return args -def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): +def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, training_default_halfp, training_default_bnb ): global args args.listen = listen @@ -1113,10 +1123,13 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v args.concurrency_count = concurrency_count args.output_sample_rate = output_sample_rate args.output_volume = output_volume + args.training_default_halfp = training_default_halfp + args.training_default_bnb = training_default_bnb save_args_settings() def save_args_settings(): + global args settings = { 'listen': None if args.listen else args.listen, 'share': args.share, @@ -1137,8 +1150,13 @@ def save_args_settings(): 'concurrency-count': args.concurrency_count, 'output-sample-rate': args.output_sample_rate, 'output-volume': args.output_volume, + + 'training-default-halfp': args.training_default_halfp, + 'training-default-bnb': args.training_default_bnb, } + print(settings) + os.makedirs('./config/', exist_ok=True) with open(f'./config/exec.json', 'w', encoding="utf-8") as f: f.write(json.dumps(settings, indent='\t') ) diff --git a/src/webui.py b/src/webui.py index 513f6cb..1ebcc43 100755 --- a/src/webui.py +++ b/src/webui.py @@ -200,11 +200,12 @@ def optimize_training_settings_proxy( *args, **kwargs ): "\n".join(tup[7]) ) -def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): +def import_training_settings_proxy( voice ): indir = f'./training/{voice}/' outdir = f'./training/{voice}-finetune/' in_config_path = f"{indir}/train.yaml" + out_configs = [] if os.path.isdir(outdir): out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ]) if len(out_configs) > 0: @@ -244,6 +245,12 @@ def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedul resume_path = f'{statedir}/{resumes[-1]}.state' messages.append(f"Latest resume found: {resume_path}") + half_p = config['fp16'] + bnb = True + + if "ext" in config and "bitsandbytes" in config["ext"]: + bnb = config["ext"]["bitsandbytes"] + messages = "\n".join(messages) return ( @@ -255,11 +262,13 @@ def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedul print_rate, save_rate, resume_path, + half_p, + bnb, messages ) -def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): +def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -297,6 +306,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, output_name=f"{voice}/train.yaml", resume_path=resume_path, half_p=half_p, + bnb=bnb, )) return "\n".join(messages) @@ -471,10 +481,12 @@ def setup_gradio(): ] training_settings = training_settings + [ gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), - gr.Checkbox(label="Half Precision", value=False), ] + training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) + training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" ) - training_settings = training_settings + [ dataset_list ] + training_settings = training_settings + [ training_halfp, training_bnb, dataset_list ] + with gr.Row(): refresh_dataset_list = gr.Button(value="Refresh Dataset List") import_dataset_button = gr.Button(value="Import Dataset") @@ -558,6 +570,8 @@ def setup_gradio(): outputs=None ) + exec_inputs = exec_inputs + [ training_halfp, training_bnb ] + for i in exec_inputs: i.change( fn=update_args, inputs=exec_inputs ) @@ -731,8 +745,8 @@ def setup_gradio(): outputs=training_settings[1:8] + [save_yaml_output] #console_output ) import_dataset_button.click(import_training_settings_proxy, - inputs=training_settings, - outputs=training_settings[:8] + [save_yaml_output] #console_output + inputs=dataset_list, + outputs=training_settings[:10] + [save_yaml_output] #console_output ) save_yaml_button.click(save_training_settings_proxy, inputs=training_settings,