Fixed some files not copying for bitsandbytes (I was wrong to assume it copied folders too), fixed stopping generation and training, and some other things I've since forgotten, as this has been slowly worked on in my small pockets of free time.
This commit is contained in:
parent e5e16bc5b5 · commit d5d8821a9d
dlas (submodule, 2 lines changed)
@@ -1 +1 @@
-Subproject commit 1433b7c0eabcc797dac8e68e9acc3043b9a28e12
+Subproject commit 0f04206aa20b1ab632c0cbf7bb6a43d5c1fd9eb0
@@ -10,7 +10,9 @@ python -m pip install -r .\tortoise-tts\requirements.txt
 python -m pip install -r .\requirements.txt
 python -m pip install -e .\tortoise-tts\
 
-copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
+xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
+xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
+xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
 
 deactivate
 pause
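The copy-to-xcopy change is the "files not copying" fix from the commit message: copy only picks up the files directly inside bitsandbytes_windows, so the cuda_setup and nn subfolders were silently skipped, and the script now copies each subfolder explicitly. For illustration only, a recursive merge of the same patch directory could be done from Python roughly as below (a sketch using the paths from the hunk above, not part of the repo):

# Sketch: recursively merge the bitsandbytes_windows patch files into the
# installed bitsandbytes package, overwriting files that already exist.
import shutil

src = r".\dlas\bitsandbytes_windows"
dst = r".\venv\Lib\site-packages\bitsandbytes"

# dirs_exist_ok=True (Python 3.8+) allows copying into an existing directory tree.
shutil.copytree(src, dst, dirs_exist_ok=True)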
@@ -1,14 +1,14 @@
 import os
 
-from utils import *
-from webui import *
-
 if 'TORTOISE_MODELS_DIR' not in os.environ:
 	os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
 
 if 'TRANSFORMERS_CACHE' not in os.environ:
 	os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
 
+from utils import *
+from webui import *
+
 if __name__ == "__main__":
 	args = setup_args()
 
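The hunk above moves the wildcard imports below the environment setup. The ordering matters because a module that reads an environment variable at import time captures whatever value is set at that moment; setting TORTOISE_MODELS_DIR or TRANSFORMERS_CACHE after `from utils import *` / `from webui import *` would be too late for anything those imports pull in. A minimal sketch of the pattern (the module name `settings` is hypothetical):

import os

# Set defaults first, only when the user has not already provided a value.
os.environ.setdefault('TORTOISE_MODELS_DIR', os.path.realpath('./models/tortoise/'))
os.environ.setdefault('TRANSFORMERS_CACHE', os.path.realpath('./models/transformers/'))

# Import afterwards: a module such as this hypothetical `settings` that does
#     CACHE = os.environ.get('TRANSFORMERS_CACHE', '/tmp/cache')
# at module level will now see the paths configured above.
import settings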
src/train.py (13 lines changed)
@@ -2,6 +2,19 @@ import os
 import sys
 import argparse
 
+"""
+if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
+	os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
+if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ:
+	os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'
+if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
+	os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
+if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
+	os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
+"""
+
 # this is some massive kludge that only works if it's called from a shell and not an import/PIP package
 # it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
 
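The new block in src/train.py is wrapped in a triple-quoted string, so it is effectively disabled. If it were ever enabled, the same "set only if absent" defaults could be written more compactly with os.environ.setdefault; a sketch under that assumption, not code from the repo:

import os

# Same defaults as the commented-out block, applied only when not already set.
defaults = {
	'BITSANDBYTES_OVERRIDE_LINEAR': '0',
	'BITSANDBYTES_OVERRIDE_EMBEDDING': '1',
	'BITSANDBYTES_OVERRIDE_ADAM': '1',
	'BITSANDBYTES_OVERRIDE_ADAMW': '1',
}
for key, value in defaults.items():
	os.environ.setdefault(key, value)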
src/utils.py (26 lines changed)
@@ -407,8 +407,8 @@ def generate(
 	)
 
 def cancel_generate():
-	from tortoise.api import STOP_SIGNAL
-	STOP_SIGNAL = True
+	import tortoise.api
+	tortoise.api.STOP_SIGNAL = True
 
 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
 	global tts
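This is the "stopping generating" fix from the commit message. `from tortoise.api import STOP_SIGNAL` copies the current value into a local name, so assigning to that name never changes the flag the generation loop actually polls; assigning through the module object does. A stripped-down illustration with a hypothetical flag module (not tortoise itself):

# flagmod.py (hypothetical):
#     STOP_SIGNAL = False

# Broken: this rebinds only the local name; flagmod.STOP_SIGNAL stays False.
from flagmod import STOP_SIGNAL
STOP_SIGNAL = True

# Working: mutate the attribute on the module object that other code reads.
import flagmod
flagmod.STOP_SIGNAL = True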
@@ -557,6 +557,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
 		if res:
 			yield res
 
-	training_state.process.stdout.close()
-	return_code = training_state.process.wait()
-	training_state = None
+	if training_state:
+		training_state.process.stdout.close()
+		return_code = training_state.process.wait()
+		training_state = None
@@ -575,10 +576,15 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr
 		yield res
 
 def stop_training():
-	global training_process
-	if training_process is None:
+	global training_state
+	if training_state is None:
 		return "No training in progress"
-	training_process.kill()
+	print("Killing training process...")
+	training_state.killed = True
+	training_state.process.stdout.close()
+	training_state.process.kill()
+	return_code = training_state.process.wait()
+	training_state = None
 	return "Training cancelled"
 
 def get_halfp_model_path():
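The rewritten stop_training tears the training subprocess down in a deliberate order: mark the state as killed so the reader loop stops parsing, close the stdout pipe to unblock any pending readline(), kill the child, wait() to reap it, then drop the state object. As a general pattern (a hypothetical helper, not the repo's own code), the same shutdown sequence looks like:

import subprocess

def stop_child(proc: subprocess.Popen) -> int:
	"""Hypothetical helper: stop a child process and reap its exit status."""
	if proc.stdout is not None:
		proc.stdout.close()   # unblocks any thread looping over readline()
	proc.kill()               # SIGKILL on POSIX, TerminateProcess on Windows
	return proc.wait()        # reap the process so it does not linger as a zombie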
@@ -1234,6 +1240,7 @@ def load_voicefixer(restart=False):
 		print("Loading Voicefixer")
 		from voicefixer import VoiceFixer
 		voicefixer = VoiceFixer()
+		print("Loaded Voicefixer")
 	except Exception as e:
 		print(f"Error occurred while tring to initialize voicefixer: {e}")
@@ -1244,6 +1251,7 @@ def unload_voicefixer():
 		print("Unloading Voicefixer")
 		del voicefixer
 		voicefixer = None
+		print("Unloaded Voicefixer")
 
 	do_gc()
@@ -1254,9 +1262,11 @@ def load_whisper_model(name=None, progress=None):
 		name = args.whisper_model
 	else:
 		args.whisper_model = name
+		save_args_settings()
 
 	notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
 	whisper_model = whisper.load_model(args.whisper_model)
+	print("Loaded Whisper model")
 
 def unload_whisper():
 	global whisper_model
@@ -1265,6 +1275,7 @@ def unload_whisper():
 		print("Unloading Whisper")
 		del whisper_model
 		whisper_model = None
+		print("Unloaded Whisper")
 
 	do_gc()
@@ -1272,8 +1283,11 @@ def update_whisper_model(name, progress=None):
 	if not name:
 		return
 
 	global whisper_model
 	if whisper_model:
 		unload_whisper()
 
-	load_whisper_model(name)
+		load_whisper_model(name)
+	else:
+		args.whisper_model = name
+		save_args_settings()
@@ -78,7 +78,7 @@ def run_generation(
 	except Exception as e:
 		message = str(e)
 		if message == "Kill signal detected":
-			reload_tts()
+			unload_tts()
 
 		raise gr.Error(message)
 
@@ -745,7 +745,7 @@ def setup_gradio():
 		if args.check_for_updates:
 			ui.load(check_for_updates)
 
-		stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
+		stop.click(fn=cancel_generate, inputs=None, outputs=None)
 
 
 		ui.queue(concurrency_count=args.concurrency_count)
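Removing cancels=[submit_event] changes how Stop behaves: instead of having Gradio hard-cancel the running event (which skips any cleanup inside the generation code), the Stop button now only calls cancel_generate, which sets tortoise.api.STOP_SIGNAL so the generation loop can exit on its own and surface the "Kill signal detected" error handled in run_generation above. For contrast, a minimal sketch of the two styles with the Gradio 3.x Blocks API (handler names hypothetical):

import gradio as gr

def long_job():
	# stand-in for the real generation call
	return "done"

def request_stop():
	# cooperative cancel: flip a flag the running job polls (as cancel_generate does)
	pass

with gr.Blocks() as ui:
	out = gr.Textbox()
	submit = gr.Button("Generate")
	stop = gr.Button("Stop")

	submit_event = submit.click(fn=long_job, inputs=None, outputs=out)

	# Hard cancel: Gradio terminates the event itself, bypassing the job's own cleanup.
	# stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])

	# Cooperative cancel (what the diff switches to): let the job notice the flag and stop.
	stop.click(fn=request_stop, inputs=None, outputs=None)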
@@ -1 +1 @@
-Subproject commit de46cf783193292b2fdf57eef1d2c328f0c08c8a
+Subproject commit 7cc0250a1a559da90965812fdefcba0d54a59c41