From e859a7c01d00493ed4c911e7088cdf08f51489a3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 3 Mar 2023 04:37:18 +0000
Subject: [PATCH] experimental multi-gpu training (Linux only, because I can't into batch files)

---
 .gitignore   |  7 ++++---
 src/train.py |  5 ++++-
 src/utils.py | 29 ++++++++++++++++-------------
 src/webui.py |  2 ++
 train.sh     | 11 ++++++++++-
 5 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/.gitignore b/.gitignore
index eb93a40..9d5407e 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,8 @@
 # ignores user files
-/tortoise-venv/
-/tortoise/voices/
-/models/
+/venv/
+/voices/*
+/models/*
+/training/*
 /config/*
 
 # Byte-compiled / optimized / DLL files
diff --git a/src/train.py b/src/train.py
index b1f61a3..ef3cb0c 100755
--- a/src/train.py
+++ b/src/train.py
@@ -18,9 +18,12 @@ if __name__ == "__main__":
 	parser = argparse.ArgumentParser()
 	parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
 	parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+	parser.add_argument('--local_rank', type=int, help='Rank Number')
 	args = parser.parse_args()
 	args.opt = " ".join(args.opt) # absolutely disgusting
 
+	os.environ['LOCAL_RANK'] = str(args.local_rank)
+
 	with open(args.opt, 'r') as file:
 		opt_config = yaml.safe_load(file)
 
@@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
 		print('Disabled distributed training.')
 	else:
 		opt['dist'] = True
-		init_dist('nccl')
+		tr.init_dist('nccl')
 		trainer.world_size = torch.distributed.get_world_size()
 		trainer.rank = torch.distributed.get_rank()
 		torch.cuda.set_device(torch.distributed.get_rank())
diff --git a/src/utils.py b/src/utils.py
index e4d36be..215fc7c 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -477,9 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 
 # superfluous, but it cleans up some things
 class TrainingState():
-	def __init__(self, config_path, keep_x_past_datasets=0, start=True):
-		self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
-
+	def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
 		# parse config to get its iteration
 		with open(config_path, 'r') as file:
 			self.config = yaml.safe_load(file)
@@ -530,9 +528,11 @@ class TrainingState():
 		if keep_x_past_datasets > 0:
 			self.cleanup_old(keep=keep_x_past_datasets)
 		if start:
-			self.spawn_process()
+			self.spawn_process(config_path=config_path, gpus=gpus)
+
+	def spawn_process(self, config_path, gpus=1):
+		self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
 
-	def spawn_process(self):
 		print("Spawning process: ", " ".join(self.cmd))
 		self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
@@ -745,7 +745,7 @@
 			if should_return:
 				return "".join(self.buffer)
 
-def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
+def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
 	global training_state
 	if training_state and training_state.process:
 		return "Training already in progress"
@@ -757,7 +757,7 @@
 	unload_whisper()
 	unload_voicefixer()
 
-	training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
+	training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
 
 	for line in iter(training_state.process.stdout.readline, ""):
 
@@ -785,11 +785,13 @@ def update_training_dataplot(config_path=None):
 	update = None
 
 	if not training_state:
-		training_state = TrainingState(config_path=config_path, start=False)
-		update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
-		del training_state
-		training_state = None
-	else:
+		if config_path:
+			training_state = TrainingState(config_path=config_path, start=False)
+			if training_state.losses:
+				update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
+			del training_state
+			training_state = None
+	elif training_state.losses:
 		update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
 
 	return update
@@ -1285,12 +1287,13 @@
 	if not args.device_override:
 		set_device_name(args.device_override)
 
+	args.listen_host = None
 	args.listen_port = None
 	args.listen_path = None
 
 	if args.listen:
 		try:
-			match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
+			match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
 
 			args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
 			args.listen_port = match[1] if match[1] != "" else None
diff --git a/src/webui.py b/src/webui.py
index a969033..a215876 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -546,6 +546,7 @@ def setup_gradio():
 					verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
 					training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
 					training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+					training_gpu_count = gr.Number(label="GPUs", value=1)
 					with gr.Row():
 						start_training_button = gr.Button(value="Train")
 						stop_training_button = gr.Button(value="Stop")
@@ -751,6 +752,7 @@
 				inputs=[
 					training_configs,
 					verbose_training,
+					training_gpu_count,
 					training_buffer_size,
 					training_keep_x_past_datasets,
 				],
diff --git a/train.sh b/train.sh
index 5e83e27..70f2651 100755
--- a/train.sh
+++ b/train.sh
@@ -1,4 +1,13 @@
 #!/bin/bash
 source ./venv/bin/activate
-python3 ./src/train.py -opt "$1"
+
+GPUS=$1
+CONFIG=$2
+PORT=1234
+
+if (( $GPUS > 1 )); then
+	python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
+else
+	python3 ./src/train.py -opt "$CONFIG"
+fi
 deactivate
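
Usage sketch (not part of the patch): with these changes, the web UI passes the GPU count through run_training() into TrainingState, which spawns `bash ./train.sh <gpus> <config>`. Invoking the launcher by hand would look roughly like the following, where the config path is a hypothetical example of a YAML generated by the web UI under ./training/:

    # single GPU: train.sh falls back to a plain train.py invocation
    bash ./train.sh 1 ./training/example/train.yaml

    # two GPUs: train.sh hands off to torch.distributed.launch, which starts one
    # train.py process per GPU and passes --local_rank to each; train.py then
    # exports it as LOCAL_RANK and initializes NCCL via tr.init_dist('nccl')
    bash ./train.sh 2 ./training/example/train.yaml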