From d5d8821a9dc748a6601db849a3978a5c33a1b496 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 24 Feb 2023 23:13:13 +0000 Subject: [PATCH] fixed some files not copying for bitsandbytes (I was wrong to assume it copied folders too), fixed stopping generating and training, some other thing that I forgot since it's been slowly worked on in my small free times --- dlas | 2 +- setup-cuda.bat | 4 +++- src/main.py | 6 +++--- src/train.py | 13 +++++++++++++ src/utils.py | 34 ++++++++++++++++++++++++---------- src/webui.py | 4 ++-- tortoise-tts | 2 +- 7 files changed, 47 insertions(+), 18 deletions(-) diff --git a/dlas b/dlas index 1433b7c..0f04206 160000 --- a/dlas +++ b/dlas @@ -1 +1 @@ -Subproject commit 1433b7c0eabcc797dac8e68e9acc3043b9a28e12 +Subproject commit 0f04206aa20b1ab632c0cbf7bb6a43d5c1fd9eb0 diff --git a/setup-cuda.bat b/setup-cuda.bat index 6ad903a..d748123 100755 --- a/setup-cuda.bat +++ b/setup-cuda.bat @@ -10,7 +10,9 @@ python -m pip install -r .\tortoise-tts\requirements.txt python -m pip install -r .\requirements.txt python -m pip install -e .\tortoise-tts\ -copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y +xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y +xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y +xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y deactivate pause \ No newline at end of file diff --git a/src/main.py b/src/main.py index 582d50a..d2d3c27 100755 --- a/src/main.py +++ b/src/main.py @@ -1,14 +1,14 @@ import os -from utils import * -from webui import * - if 'TORTOISE_MODELS_DIR' not in os.environ: os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/')) if 'TRANSFORMERS_CACHE' not in os.environ: os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/')) +from utils import * +from webui import * + if __name__ == "__main__": args = setup_args() diff --git a/src/train.py b/src/train.py index bd0656a..a3e480e 100755 --- a/src/train.py +++ b/src/train.py @@ -2,6 +2,19 @@ import os import sys import argparse + +""" +if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ: + os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' +if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ: + os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1' +if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ: + os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1' +if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ: + os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1' +""" + + # this is some massive kludge that only works if it's called from a shell and not an import/PIP package # it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell diff --git a/src/utils.py b/src/utils.py index 586020c..d8150ed 100755 --- a/src/utils.py +++ b/src/utils.py @@ -407,8 +407,8 @@ def generate( ) def cancel_generate(): - from tortoise.api import STOP_SIGNAL - STOP_SIGNAL = True + import tortoise.api + tortoise.api.STOP_SIGNAL = True def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): global tts @@ -557,9 +557,10 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress if res: yield res - training_state.process.stdout.close() - return_code = training_state.process.wait() - training_state = None + if training_state: + training_state.process.stdout.close() + return_code = training_state.process.wait() + training_state = None #if return_code: # raise subprocess.CalledProcessError(return_code, cmd) @@ -575,10 +576,15 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr yield res def stop_training(): - global training_process - if training_process is None: + global training_state + if training_state is None: return "No training in progress" - training_process.kill() + print("Killing training process...") + training_state.killed = True + training_state.process.stdout.close() + training_state.process.kill() + return_code = training_state.process.wait() + training_state = None return "Training cancelled" def get_halfp_model_path(): @@ -1234,6 +1240,7 @@ def load_voicefixer(restart=False): print("Loading Voicefixer") from voicefixer import VoiceFixer voicefixer = VoiceFixer() + print("Loaded Voicefixer") except Exception as e: print(f"Error occurred while tring to initialize voicefixer: {e}") @@ -1244,6 +1251,7 @@ def unload_voicefixer(): print("Unloading Voicefixer") del voicefixer voicefixer = None + print("Unloaded Voicefixer") do_gc() @@ -1254,9 +1262,11 @@ def load_whisper_model(name=None, progress=None): name = args.whisper_model else: args.whisper_model = name + save_args_settings() notify_progress(f"Loading Whisper model: {args.whisper_model}", progress) whisper_model = whisper.load_model(args.whisper_model) + print("Loaded Whisper model") def unload_whisper(): global whisper_model @@ -1265,6 +1275,7 @@ def unload_whisper(): print("Unloading Whisper") del whisper_model whisper_model = None + print("Unloaded Whisper") do_gc() @@ -1272,8 +1283,11 @@ def update_whisper_model(name, progress=None): if not name: return + global whisper_model if whisper_model: unload_whisper() - - load_whisper_model(name) \ No newline at end of file + load_whisper_model(name) + else: + args.whisper_model = name + save_args_settings() \ No newline at end of file diff --git a/src/webui.py b/src/webui.py index d6c16ef..513f6cb 100755 --- a/src/webui.py +++ b/src/webui.py @@ -78,7 +78,7 @@ def run_generation( except Exception as e: message = str(e) if message == "Kill signal detected": - reload_tts() + unload_tts() raise gr.Error(message) @@ -745,7 +745,7 @@ def setup_gradio(): if args.check_for_updates: ui.load(check_for_updates) - stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event]) + stop.click(fn=cancel_generate, inputs=None, outputs=None) ui.queue(concurrency_count=args.concurrency_count) diff --git a/tortoise-tts b/tortoise-tts index de46cf7..7cc0250 160000 --- a/tortoise-tts +++ b/tortoise-tts @@ -1 +1 @@ -Subproject commit de46cf783193292b2fdf57eef1d2c328f0c08c8a +Subproject commit 7cc0250a1a559da90965812fdefcba0d54a59c41