From 1fd88afccad45f5579412c9c380eb0e9cf59ee6d Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 20 Feb 2023 22:56:39 +0000 Subject: [PATCH] updated notebook for newer setup structure, added formatting of getting it/s and lass loss rate (have not tested loss rate yet) --- notebook.ipynb | 22 +++++++++------------- src/utils.py | 37 ++++++++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/notebook.ipynb b/notebook.ipynb index f484d38..cb019a8 100755 --- a/notebook.ipynb +++ b/notebook.ipynb @@ -41,17 +41,20 @@ "!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n", "%cd ai-voice-cloning\n", "\n", + "!git submodule init\n", + "!git submodule update\n", + "\n", "# TODO: fix venvs working for subprocess.Popen calling a bash script\n", "#!apt install python3.8-venv\n", "#!python -m venv venv\n", "#!source ./venv/bin/activate\n", "\n", - "!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n", "!python -m pip install --upgrade pip\n", "!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n", - "!./setup-tortoise.sh\n", - "!./setup-training.sh\n", - "!python -m pip install -r ./requirements.txt" + "!python -m pip install -r ./dlas/requirements.txt\n", + "!python -m pip install -r ./tortoise-tts/requirements.txt\n", + "!python -m pip install -r ./requirements.txt\n", + "!python -m pip install -e ./tortoise-tts/" ] }, { @@ -67,15 +70,8 @@ "cell_type":"code", "source":[ "# for my debugging purposes\n", - "%cd /content/ai-voice-cloning/dlas\n", - "!git reset --hard HEAD\n", - "!git pull\n", - "%cd ../tortoise-tts/\n", - "!git reset --hard HEAD\n", - "!git pull\n", - "!cd ..\n", - "!git reset --hard HEAD\n", - "!git pull\n", + "%cd /content/ai-voice-cloning/\n", + "!./update.sh\n", "# exit()" ], "metadata":{ diff --git a/src/utils.py b/src/utils.py index 5cd2121..edb6487 100755 --- a/src/utils.py +++ b/src/utils.py @@ -417,9 +417,16 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress yield " ".join(cmd) + info = {} buffer = [] infos = [] yields = True + status = "" + + it_rate = "" + it_time_start = 0 + it_time_end = 0 + for line in iter(training_process.stdout.readline, ""): buffer.append(f'{line}') @@ -430,13 +437,34 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress elif progress is not None: if line.find(' 0%|') == 0: open_state = True + it_time_start = time.time() elif line.find('100%|') == 0 and open_state: + it_time_end = time.time() open_state = False it = it + 1 - progress(it / float(its), f'[{it}/{its}] Training...') - elif line.find('INFO: [epoch:') >= 0: - infos.append(f'{line}') - elif line.find('Saving models and training states') >= 0: + + it_time_delta = it_time_end-it_time_start + it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 and it_time_delta != 0 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here + + progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}') + + # try because I haven't tested this yet + try: + if line.find('INFO: [epoch:') >= 0: + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) + if match and len(match) > 0: + for k, v in match: + info[k] = float(v) + + # ...and returns our loss rate + # it would be nice for losses to be shown at every step + if 'loss_gpt_total' in info: + status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}" + except Exception as e: + pass + + if line.find('Saving models and training states') >= 0: checkpoint = checkpoint + 1 progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...') @@ -459,7 +487,6 @@ def stop_training(): if training_process is None: return "No training in progress" training_process.kill() - training_process = None return "Training cancelled" def prepare_dataset( files, outdir, language=None, progress=None ):