forked from mrq/ai-voice-cloning
fixed some files not copying for bitsandbytes (I was wrong to assume it copied folders too), fixed stopping generating and training, some other thing that I forgot since it's been slowly worked on in my small free times
This commit is contained in:
parent
e5e16bc5b5
commit
d5d8821a9d
2
dlas
2
dlas
|
@ -1 +1 @@
|
|||
Subproject commit 1433b7c0eabcc797dac8e68e9acc3043b9a28e12
|
||||
Subproject commit 0f04206aa20b1ab632c0cbf7bb6a43d5c1fd9eb0
|
|
@ -10,7 +10,9 @@ python -m pip install -r .\tortoise-tts\requirements.txt
|
|||
python -m pip install -r .\requirements.txt
|
||||
python -m pip install -e .\tortoise-tts\
|
||||
|
||||
copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
|
||||
xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
|
||||
xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
|
||||
xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
|
||||
|
||||
deactivate
|
||||
pause
|
|
@ -1,14 +1,14 @@
|
|||
import os
|
||||
|
||||
from utils import *
|
||||
from webui import *
|
||||
|
||||
if 'TORTOISE_MODELS_DIR' not in os.environ:
|
||||
os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
|
||||
|
||||
if 'TRANSFORMERS_CACHE' not in os.environ:
|
||||
os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
|
||||
|
||||
from utils import *
|
||||
from webui import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = setup_args()
|
||||
|
||||
|
|
13
src/train.py
13
src/train.py
|
@ -2,6 +2,19 @@ import os
|
|||
import sys
|
||||
import argparse
|
||||
|
||||
|
||||
"""
|
||||
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||
if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'
|
||||
if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
|
||||
if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
|
||||
"""
|
||||
|
||||
|
||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||
|
||||
|
|
34
src/utils.py
34
src/utils.py
|
@ -407,8 +407,8 @@ def generate(
|
|||
)
|
||||
|
||||
def cancel_generate():
|
||||
from tortoise.api import STOP_SIGNAL
|
||||
STOP_SIGNAL = True
|
||||
import tortoise.api
|
||||
tortoise.api.STOP_SIGNAL = True
|
||||
|
||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
global tts
|
||||
|
@ -557,9 +557,10 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
|||
if res:
|
||||
yield res
|
||||
|
||||
training_state.process.stdout.close()
|
||||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
if training_state:
|
||||
training_state.process.stdout.close()
|
||||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
|
||||
#if return_code:
|
||||
# raise subprocess.CalledProcessError(return_code, cmd)
|
||||
|
@ -575,10 +576,15 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr
|
|||
yield res
|
||||
|
||||
def stop_training():
|
||||
global training_process
|
||||
if training_process is None:
|
||||
global training_state
|
||||
if training_state is None:
|
||||
return "No training in progress"
|
||||
training_process.kill()
|
||||
print("Killing training process...")
|
||||
training_state.killed = True
|
||||
training_state.process.stdout.close()
|
||||
training_state.process.kill()
|
||||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
return "Training cancelled"
|
||||
|
||||
def get_halfp_model_path():
|
||||
|
@ -1234,6 +1240,7 @@ def load_voicefixer(restart=False):
|
|||
print("Loading Voicefixer")
|
||||
from voicefixer import VoiceFixer
|
||||
voicefixer = VoiceFixer()
|
||||
print("Loaded Voicefixer")
|
||||
except Exception as e:
|
||||
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
||||
|
||||
|
@ -1244,6 +1251,7 @@ def unload_voicefixer():
|
|||
print("Unloading Voicefixer")
|
||||
del voicefixer
|
||||
voicefixer = None
|
||||
print("Unloaded Voicefixer")
|
||||
|
||||
do_gc()
|
||||
|
||||
|
@ -1254,9 +1262,11 @@ def load_whisper_model(name=None, progress=None):
|
|||
name = args.whisper_model
|
||||
else:
|
||||
args.whisper_model = name
|
||||
save_args_settings()
|
||||
|
||||
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
|
||||
whisper_model = whisper.load_model(args.whisper_model)
|
||||
print("Loaded Whisper model")
|
||||
|
||||
def unload_whisper():
|
||||
global whisper_model
|
||||
|
@ -1265,6 +1275,7 @@ def unload_whisper():
|
|||
print("Unloading Whisper")
|
||||
del whisper_model
|
||||
whisper_model = None
|
||||
print("Unloaded Whisper")
|
||||
|
||||
do_gc()
|
||||
|
||||
|
@ -1272,8 +1283,11 @@ def update_whisper_model(name, progress=None):
|
|||
if not name:
|
||||
return
|
||||
|
||||
|
||||
global whisper_model
|
||||
if whisper_model:
|
||||
unload_whisper()
|
||||
|
||||
load_whisper_model(name)
|
||||
load_whisper_model(name)
|
||||
else:
|
||||
args.whisper_model = name
|
||||
save_args_settings()
|
|
@ -78,7 +78,7 @@ def run_generation(
|
|||
except Exception as e:
|
||||
message = str(e)
|
||||
if message == "Kill signal detected":
|
||||
reload_tts()
|
||||
unload_tts()
|
||||
|
||||
raise gr.Error(message)
|
||||
|
||||
|
@ -745,7 +745,7 @@ def setup_gradio():
|
|||
if args.check_for_updates:
|
||||
ui.load(check_for_updates)
|
||||
|
||||
stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
|
||||
stop.click(fn=cancel_generate, inputs=None, outputs=None)
|
||||
|
||||
|
||||
ui.queue(concurrency_count=args.concurrency_count)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit de46cf783193292b2fdf57eef1d2c328f0c08c8a
|
||||
Subproject commit 7cc0250a1a559da90965812fdefcba0d54a59c41
|
Loading…
Reference in New Issue
Block a user