From 225dee22d47ebba84e6e631ae65283eba11b7a4b Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 23 Feb 2023 06:24:54 +0000
Subject: [PATCH] huge success

---
 README.md             |  17 +++-
 dlas                  |   2 +-
 models/.template.yaml |   2 +-
 setup-cuda.bat        |   3 +
 src/train.py          |  17 +++-
 src/utils.py          | 188 ++++++++++++++++++++++++------------------
 src/webui.py          |   8 ++
 update.bat            |   2 +-
 update.sh             |   2 +-
 9 files changed, 154 insertions(+), 87 deletions(-)

diff --git a/README.md b/README.md
index 768606b..9498ed9 100755
--- a/README.md
+++ b/README.md
@@ -16,4 +16,19 @@ Please consult [the wiki](https://git.ecker.tech/mrq/ai-voice-cloning/wiki) for
 
 ## Bug Reporting
 
-If you run into any problems, please refer to the [issues you may encounter](https://git.ecker.tech/mrq/ai-voice-cloning/wiki/Issues) wiki page first. Please don't hesitate to submit an issue.
\ No newline at end of file
+If you run into any problems, please refer to the [issues you may encounter](https://git.ecker.tech/mrq/ai-voice-cloning/wiki/Issues) wiki page first. Please don't hesitate to submit an issue.
+
+## Changelogs
+
+Below is a rather loose changelog, as I don't think I have a way to chronicle changes outside of commit messages:
+
+### `2023.02.22`
+
+* greatly reduced VRAM consumption through the use of [TimDettmers/bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+* cleaned up the section of code that handles parsing output from the training script
+* added a button to reconnect to the training script's output (it sometimes skips a line before updating, but it's better than nothing)
+* actually update submodules from the update script (somehow forgot to pass `--remote`)
+
+### `Before 2023.02.22`
+
+Refer to the commit logs.
\ No newline at end of file
diff --git a/dlas b/dlas
index 6c284ef..0ef8ab6 160000
--- a/dlas
+++ b/dlas
@@ -1 +1 @@
-Subproject commit 6c284ef8ec4c4769de3181d90ac96ff63581ef55
+Subproject commit 0ef8ab6872813d1021d4d75e82b63377d28f5a06
diff --git a/models/.template.yaml b/models/.template.yaml
index d89c889..7f92c48 100755
--- a/models/.template.yaml
+++ b/models/.template.yaml
@@ -2,7 +2,7 @@ name: ${name}
 model: extensibletrainer
 scale: 1
 gpu_ids: [0] # <-- unless you have multiple gpus, use this
-start_step: -1
+start_step: 0
 checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
 fp16: ${float16} # might want to check this out
 wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
diff --git a/setup-cuda.bat b/setup-cuda.bat
index 53a29f3..c8346e2 100755
--- a/setup-cuda.bat
+++ b/setup-cuda.bat
@@ -9,5 +9,8 @@ python -m pip install -r .\dlas\requirements.txt
 python -m pip install -r .\tortoise-tts\requirements.txt
 python -m pip install -r .\requirements.txt
 python -m pip install -e .\tortoise-tts\
+
+copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
+
 deactivate
 pause
\ No newline at end of file
diff --git a/src/train.py b/src/train.py
index 941ec5c..900261f 100755
--- a/src/train.py
+++ b/src/train.py
@@ -1,8 +1,8 @@
-import torch
-import argparse
-
 import os
 import sys
+import argparse
+
+
 
 # this is some massive kludge that only works if it's called from a shell and not an import/PIP package
 # it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
@@ -19,6 +19,17 @@ sys.path.insert(0, './dlas/')
 # don't even really bother trying to get DLAS PIP'd
 # without kludge, it'll have to be accessible as `codes` and not `dlas`
 
+import torch_intermediary
+# could just move this auto-toggle into the MITM script
+try:
+    import bitsandbytes as bnb
+    torch_intermediary.OVERRIDE_ADAM = True
+    torch_intermediary.OVERRIDE_ADAMW = True
+except Exception as e:
+    torch_intermediary.OVERRIDE_ADAM = False
+    torch_intermediary.OVERRIDE_ADAMW = False
+
+import torch
 from codes import train as tr
 from utils import util, options as option
 
diff --git a/src/utils.py b/src/utils.py
index 10e9a4e..fa5a242 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -17,6 +17,7 @@ import urllib.request
 import signal
 import gc
 import subprocess
+import yaml
 
 import tqdm
 import torch
@@ -26,6 +27,7 @@ import gradio as gr
 import gradio.utils
 
 from datetime import datetime
+from datetime import timedelta
 
 from tortoise.api import TextToSpeech, MODELS, get_model_path
 from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
@@ -42,7 +44,7 @@ tts_loading = False
 webui = None
 voicefixer = None
 whisper_model = None
-training_process = None
+training_state = None
 
 
 def generate(
@@ -434,8 +436,88 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 
     return voice
 
+# superfluous, but it cleans up some things
+class TrainingState():
+    def __init__(self, config_path, buffer_size=8):
+        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
+
+        # parse config to get its iteration
+        with open(config_path, 'r') as file:
+            self.config = yaml.safe_load(file)
+
+        self.it = 0
+        self.its = self.config['train']['niter']
+
+        self.checkpoint = 0
+        self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
+
+        self.buffer = []
+
+        self.open_state = False
+        self.training_started = False
+
+        self.info = {}
+        self.status = ""
+
+        self.it_rate = ""
+        self.it_time_start = 0
+        self.it_time_end = 0
+        self.eta = "?"
+ + print("Spawning process: ", " ".join(self.cmd)) + self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) + + def parse(self, line, verbose=False, buffer_size=8, progress=None): + self.buffer.append(f'{line}') + + # rip out iteration info + if not self.training_started: + if line.find('Start training from epoch') >= 0: + self.it_time_start = time.time() + self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations + + match = re.findall(r'iter: ([\d,]+)', line) + if match and len(match) > 0: + self.it = int(match[0].replace(",", "")) + elif progress is not None: + if line.find(' 0%|') == 0: + self.open_state = True + elif line.find('100%|') == 0 and self.open_state: + self.open_state = False + self.it = self.it + 1 + + self.it_time_end = time.time() + self.it_time_delta = self.it_time_end-self.it_time_start + self.it_time_start = time.time() + self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here + self.eta = (self.its - self.it) * self.it_time_delta + self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) + + progress(self.it / float(self.its), f'[{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.it_rate} Training... {self.status}') + + if line.find('INFO: [epoch:') >= 0: + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) + if match and len(match) > 0: + for k, v in match: + self.info[k] = float(v) + + # ...and returns our loss rate + # it would be nice for losses to be shown at every step + if 'loss_gpt_total' in self.info: + # self.info['step'] returns the steps, not iterations, so we won't even bother ripping the reported step count, as iteration count won't get ripped from the regex + self.status = f"Total loss at iteration {self.it}: {self.info['loss_gpt_total']}" + elif line.find('Saving models and training states') >= 0: + self.checkpoint = self.checkpoint + 1 + progress(self.checkpoint / float(self.checkpoints), f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...') + + if verbose or not self.training_started: + return "".join(self.buffer[-buffer_size:]) + def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): - global training_process + global training_state + if training_state and training_state.process: + return "Training already in progress" # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process torch.multiprocessing.freeze_support() @@ -444,90 +526,38 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress unload_whisper() unload_voicefixer() - cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path] - print("Spawning process: ", " ".join(cmd)) - training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) - - # parse config to get its iteration - import yaml - with open(config_path, 'r') as file: - config = yaml.safe_load(file) - - it = 0 - its = config['train']['niter'] - - checkpoint = 0 - checkpoints = its / config['logger']['save_checkpoint_freq'] - - buffer_size = 8 - open_state = False - training_started = False - - yield " ".join(cmd) - - info = {} - buffer = [] - infos = [] 
-    yields = True
-    status = ""
-
-    it_rate = ""
-    it_time_start = 0
-    it_time_end = 0
-
-    for line in iter(training_process.stdout.readline, ""):
-        buffer.append(f'{line}')
-
-        # rip out iteration info
-        if not training_started:
-            if line.find('Start training from epoch') >= 0:
-                training_started = True
-
-            match = re.findall(r'iter: ([\d,]+)', line)
-            if match and len(match) > 0:
-                it = int(match[0].replace(",", ""))
-        elif progress is not None:
-            if line.find(' 0%|') == 0:
-                open_state = True
-            elif line.find('100%|') == 0 and open_state:
-                open_state = False
-                it = it + 1
-
-                it_time_end = time.time()
-                it_time_delta = it_time_end-it_time_start
-                it_time_start = time.time()
-                it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
-
-                progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}')
-
-            if line.find('INFO: [epoch:') >= 0:
-                # easily rip out our stats...
-                match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
-                if match and len(match) > 0:
-                    for k, v in match:
-                        info[k] = float(v)
-
-                # ...and returns our loss rate
-                # it would be nice for losses to be shown at every step
-                if 'loss_gpt_total' in info:
-                    status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}"
-            elif line.find('Saving models and training states') >= 0:
-                checkpoint = checkpoint + 1
-                progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
+    training_state = TrainingState(config_path=config_path, buffer_size=buffer_size)
 
+    for line in iter(training_state.process.stdout.readline, ""):
         print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
+
+        res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+        if res:
+            yield res
 
-        if verbose or not training_started:
-            yield "".join(buffer[-buffer_size:])
-
-    training_process.stdout.close()
-    return_code = training_process.wait()
-    training_process = None
+    training_state.process.stdout.close()
+    return_code = training_state.process.wait()
+    output = "".join(training_state.buffer[-buffer_size:])
+    training_state = None
 
     #if return_code:
     #    raise subprocess.CalledProcessError(return_code, cmd)
 
-    return "".join(buffer[-buffer_size:])
+    return output
+
+def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
+    global training_state
+    if not training_state or not training_state.process:
+        return "Training not in progress"
+
+    for line in iter(training_state.process.stdout.readline, ""):
+        res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+        if res:
+            yield res
+
+    output = "".join(training_state.buffer[-buffer_size:])
+
+    return output
 
 def stop_training():
     global training_process
diff --git a/src/webui.py b/src/webui.py
index 68519d4..cc52cf2 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -410,6 +410,7 @@ def setup_gradio():
                     refresh_configs = gr.Button(value="Refresh Configurations")
                     start_training_button = gr.Button(value="Train")
                     stop_training_button = gr.Button(value="Stop")
+                    reconnect_training_button = gr.Button(value="Reconnect")
                 with gr.Column():
                     training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
                     verbose_training = gr.Checkbox(label="Verbose Console Output")
@@ -614,6 +615,13 @@ def setup_gradio():
            inputs=None,
            outputs=training_output #console_output
        )
+       reconnect_training_button.click(reconnect_training,
+           inputs=[
+               verbose_training,
+               training_buffer_size,
+           ],
+           outputs=training_output #console_output
+       )
        prepare_dataset_button.click( prepare_dataset_proxy,
            inputs=dataset_settings,
diff --git a/update.bat b/update.bat
index 2a3962a..843d5d8 100755
--- a/update.bat
+++ b/update.bat
@@ -1,5 +1,5 @@
 git pull
-git submodule update
+git submodule update --remote
 
 python -m venv venv
 call .\venv\Scripts\activate.bat
diff --git a/update.sh b/update.sh
index 58d852b..584a8a4 100755
--- a/update.sh
+++ b/update.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 git pull
-git submodule update
+git submodule update --remote
 
 python3 -m venv venv
 source ./venv/bin/activate