fixed some files not copying for bitsandbytes (I was wrong to assume it copied folders too), fixed stopping generating and training, some other thing that I forgot since it's been slowly worked on in my small free times

This commit is contained in:
mrq 2023-02-24 23:13:13 +00:00
parent e5e16bc5b5
commit d5d8821a9d
7 changed files with 47 additions and 18 deletions

2
dlas

@ -1 +1 @@
Subproject commit 1433b7c0eabcc797dac8e68e9acc3043b9a28e12
Subproject commit 0f04206aa20b1ab632c0cbf7bb6a43d5c1fd9eb0

View File

@ -10,7 +10,9 @@ python -m pip install -r .\tortoise-tts\requirements.txt
python -m pip install -r .\requirements.txt
python -m pip install -e .\tortoise-tts\
copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
deactivate
pause

View File

@ -1,14 +1,14 @@
import os
from utils import *
from webui import *
if 'TORTOISE_MODELS_DIR' not in os.environ:
os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
if 'TRANSFORMERS_CACHE' not in os.environ:
os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
from utils import *
from webui import *
if __name__ == "__main__":
args = setup_args()

View File

@ -2,6 +2,19 @@ import os
import sys
import argparse
"""
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'
if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
"""
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell

View File

@ -407,8 +407,8 @@ def generate(
)
def cancel_generate():
from tortoise.api import STOP_SIGNAL
STOP_SIGNAL = True
import tortoise.api
tortoise.api.STOP_SIGNAL = True
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts
@ -557,6 +557,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
if res:
yield res
if training_state:
training_state.process.stdout.close()
return_code = training_state.process.wait()
training_state = None
@ -575,10 +576,15 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr
yield res
def stop_training():
global training_process
if training_process is None:
global training_state
if training_state is None:
return "No training in progress"
training_process.kill()
print("Killing training process...")
training_state.killed = True
training_state.process.stdout.close()
training_state.process.kill()
return_code = training_state.process.wait()
training_state = None
return "Training cancelled"
def get_halfp_model_path():
@ -1234,6 +1240,7 @@ def load_voicefixer(restart=False):
print("Loading Voicefixer")
from voicefixer import VoiceFixer
voicefixer = VoiceFixer()
print("Loaded Voicefixer")
except Exception as e:
print(f"Error occurred while tring to initialize voicefixer: {e}")
@ -1244,6 +1251,7 @@ def unload_voicefixer():
print("Unloading Voicefixer")
del voicefixer
voicefixer = None
print("Unloaded Voicefixer")
do_gc()
@ -1254,9 +1262,11 @@ def load_whisper_model(name=None, progress=None):
name = args.whisper_model
else:
args.whisper_model = name
save_args_settings()
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
whisper_model = whisper.load_model(args.whisper_model)
print("Loaded Whisper model")
def unload_whisper():
global whisper_model
@ -1265,6 +1275,7 @@ def unload_whisper():
print("Unloading Whisper")
del whisper_model
whisper_model = None
print("Unloaded Whisper")
do_gc()
@ -1272,8 +1283,11 @@ def update_whisper_model(name, progress=None):
if not name:
return
global whisper_model
if whisper_model:
unload_whisper()
load_whisper_model(name)
else:
args.whisper_model = name
save_args_settings()

View File

@ -78,7 +78,7 @@ def run_generation(
except Exception as e:
message = str(e)
if message == "Kill signal detected":
reload_tts()
unload_tts()
raise gr.Error(message)
@ -745,7 +745,7 @@ def setup_gradio():
if args.check_for_updates:
ui.load(check_for_updates)
stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
stop.click(fn=cancel_generate, inputs=None, outputs=None)
ui.queue(concurrency_count=args.concurrency_count)

@ -1 +1 @@
Subproject commit de46cf783193292b2fdf57eef1d2c328f0c08c8a
Subproject commit 7cc0250a1a559da90965812fdefcba0d54a59c41