forked from mrq/ai-voice-cloning
added button to just load a training set's loss information, added installing broncotc/bitsandbytes-rocm when running setup-rocm.sh
This commit is contained in:
parent
534a761e49
commit
c956d81baf
|
@ -7,12 +7,12 @@ python -m pip install --upgrade pip
|
||||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
|
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
|
||||||
python -m pip install -r .\dlas\requirements.txt
|
python -m pip install -r .\dlas\requirements.txt
|
||||||
python -m pip install -r .\tortoise-tts\requirements.txt
|
python -m pip install -r .\tortoise-tts\requirements.txt
|
||||||
python -m pip install -r .\requirements.txt
|
|
||||||
python -m pip install -e .\tortoise-tts\
|
python -m pip install -e .\tortoise-tts\
|
||||||
|
python -m pip install -r .\requirements.txt
|
||||||
|
|
||||||
xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
|
xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
|
||||||
xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
|
xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
|
||||||
xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
|
xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
|
||||||
|
|
||||||
deactivate
|
|
||||||
pause
|
pause
|
||||||
|
deactivate
|
|
@ -1,14 +1,17 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
# get local dependencies
|
||||||
git submodule init
|
git submodule init
|
||||||
git submodule update --remote
|
git submodule update --remote
|
||||||
|
# setup venv
|
||||||
python3 -m venv venv
|
python3 -m venv venv
|
||||||
source ./venv/bin/activate
|
source ./venv/bin/activate
|
||||||
python3 -m pip install --upgrade pip
|
python3 -m pip install --upgrade pip # just to be safe
|
||||||
# CUDA
|
# CUDA
|
||||||
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
|
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
|
||||||
python3 -m pip install -r ./dlas/requirements.txt
|
# install requirements
|
||||||
python3 -m pip install -r ./tortoise-tts/requirements.txt
|
python3 -m pip install -r ./dlas/requirements.txt # instal DLAS requirements
|
||||||
python3 -m pip install -r ./requirements.txt
|
python3 -m pip install -r ./tortoise-tts/requirements.txt # install TorToiSe requirements
|
||||||
python3 -m pip install -e ./tortoise-tts/
|
python3 -m pip install -e ./tortoise-tts/ # install TorToiSe
|
||||||
|
python3 -m pip install -r ./requirements.txt # install local requirements
|
||||||
|
|
||||||
deactivate
|
deactivate
|
|
@ -4,10 +4,11 @@ git submodule update --remote
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
call .\venv\Scripts\activate.bat
|
call .\venv\Scripts\activate.bat
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install torch torchvision torchaudio torch-directml==0.1.13.1.dev230119
|
python -m pip install torch torchvision torchaudio torch-directml
|
||||||
python -m pip install -r .\dlas\requirements.txt
|
python -m pip install -r .\dlas\requirements.txt
|
||||||
python -m pip install -r .\tortoise-tts\requirements.txt
|
python -m pip install -r .\tortoise-tts\requirements.txt
|
||||||
python -m pip install -r .\requirements.txt
|
|
||||||
python -m pip install -e .\tortoise-tts\
|
python -m pip install -e .\tortoise-tts\
|
||||||
deactivate
|
python -m pip install -r .\requirements.txt
|
||||||
|
|
||||||
pause
|
pause
|
||||||
|
deactivate
|
|
@ -1,14 +1,20 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
# get local dependencies
|
||||||
git submodule init
|
git submodule init
|
||||||
git submodule update --remote
|
git submodule update --remote
|
||||||
|
# setup venv
|
||||||
python3 -m venv venv
|
python3 -m venv venv
|
||||||
source ./venv/bin/activate
|
source ./venv/bin/activate
|
||||||
python3 -m pip install --upgrade pip
|
python3 -m pip install --upgrade pip # just to be safe
|
||||||
# ROCM
|
# ROCM
|
||||||
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
|
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
|
||||||
python3 -m pip install -r ./dlas/requirements.txt
|
# install requirements
|
||||||
python3 -m pip install -r ./tortoise-tts/requirements.txt
|
python3 -m pip install -r ./dlas/requirements.txt # instal DLAS requirements
|
||||||
python3 -m pip install -r ./requirements.txt
|
python3 -m pip install -r ./tortoise-tts/requirements.txt # install TorToiSe requirements
|
||||||
python3 -m pip install -e ./tortoise-tts/
|
python3 -m pip install -e ./tortoise-tts/ # install TorToiSe
|
||||||
|
python3 -m pip install -r ./requirements.txt # install local requirements
|
||||||
|
# swap to ROCm version of BitsAndBytes
|
||||||
|
pip3 uninstall bitsandbytes
|
||||||
|
pip3 install git+https://github.com/broncotc/bitsandbytes-rocm
|
||||||
|
|
||||||
deactivate
|
deactivate
|
20
src/utils.py
20
src/utils.py
|
@ -477,7 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
|
|
||||||
# superfluous, but it cleans up some things
|
# superfluous, but it cleans up some things
|
||||||
class TrainingState():
|
class TrainingState():
|
||||||
def __init__(self, config_path, keep_x_past_datasets=0):
|
def __init__(self, config_path, keep_x_past_datasets=0, start=True):
|
||||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
||||||
|
|
||||||
# parse config to get its iteration
|
# parse config to get its iteration
|
||||||
|
@ -527,7 +527,9 @@ class TrainingState():
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
self.load_losses()
|
self.load_losses()
|
||||||
|
if keep_x_past_datasets > 0:
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
|
if start:
|
||||||
self.spawn_process()
|
self.spawn_process()
|
||||||
|
|
||||||
def spawn_process(self):
|
def spawn_process(self):
|
||||||
|
@ -778,11 +780,19 @@ def get_training_losses():
|
||||||
return
|
return
|
||||||
return pd.DataFrame(training_state.losses)
|
return pd.DataFrame(training_state.losses)
|
||||||
|
|
||||||
def update_training_dataplot():
|
def update_training_dataplot(config_path=None):
|
||||||
global training_state
|
global training_state
|
||||||
if not training_state or not training_state.losses:
|
update = None
|
||||||
return
|
|
||||||
return gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
if not training_state:
|
||||||
|
training_state = TrainingState(config_path=config_path, start=False)
|
||||||
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||||
|
del training_state
|
||||||
|
training_state = None
|
||||||
|
else:
|
||||||
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||||
|
|
||||||
|
return update
|
||||||
|
|
||||||
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_state
|
global training_state
|
||||||
|
|
34
src/webui.py
34
src/webui.py
|
@ -527,16 +527,8 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
||||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_training_button = gr.Button(value="Train")
|
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||||
stop_training_button = gr.Button(value="Stop")
|
|
||||||
reconnect_training_button = gr.Button(value="Reconnect")
|
|
||||||
with gr.Column():
|
|
||||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
|
||||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
|
||||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
|
||||||
|
|
||||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||||
x="step",
|
x="step",
|
||||||
|
@ -545,8 +537,19 @@ def setup_gradio():
|
||||||
color="type",
|
color="type",
|
||||||
tooltip=['step', 'value', 'type'],
|
tooltip=['step', 'value', 'type'],
|
||||||
width=600,
|
width=600,
|
||||||
height=350
|
height=350,
|
||||||
)
|
)
|
||||||
|
view_losses = gr.Button(value="View Losses")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
|
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||||
|
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||||
|
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||||
|
with gr.Row():
|
||||||
|
start_training_button = gr.Button(value="Train")
|
||||||
|
stop_training_button = gr.Button(value="Stop")
|
||||||
|
reconnect_training_button = gr.Button(value="Reconnect")
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
exec_inputs = []
|
exec_inputs = []
|
||||||
|
@ -763,6 +766,17 @@ def setup_gradio():
|
||||||
],
|
],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
view_losses.click(
|
||||||
|
fn=update_training_dataplot,
|
||||||
|
inputs=[
|
||||||
|
training_configs
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
training_loss_graph,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
stop_training_button.click(stop_training,
|
stop_training_button.click(stop_training,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=training_output #console_output
|
outputs=training_output #console_output
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
call .\venv\Scripts\activate.bat
|
call .\venv\Scripts\activate.bat
|
||||||
set PATH=.\bin\;%PATH%
|
set PATH=.\bin\;%PATH%
|
||||||
python .\src\main.py %*
|
python .\src\main.py %*
|
||||||
deactivate
|
|
||||||
pause
|
pause
|
|
@ -1,4 +1,4 @@
|
||||||
call .\venv\Scripts\activate.bat
|
call .\venv\Scripts\activate.bat
|
||||||
python ./src/train.py -opt "%1"
|
python ./src/train.py -opt "%1"
|
||||||
deactivate
|
|
||||||
pause
|
pause
|
||||||
|
deactivate
|
|
@ -1,3 +1,15 @@
|
||||||
git fetch --all
|
git fetch --all
|
||||||
git reset --hard origin/master
|
git reset --hard origin/master
|
||||||
call .\update.bat
|
call .\update.bat
|
||||||
|
|
||||||
|
python -m venv venv
|
||||||
|
call .\venv\Scripts\activate.bat
|
||||||
|
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install -r .\dlas\requirements.txt
|
||||||
|
python -m pip install -r .\tortoise-tts\requirements.txt
|
||||||
|
python -m pip install -e .\tortoise-tts
|
||||||
|
python -m pip install -r .\requirements.txt
|
||||||
|
|
||||||
|
pause
|
||||||
|
deactivate
|
|
@ -1,4 +1,17 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
git fetch --all
|
git fetch --all
|
||||||
git reset --hard origin/master
|
git reset --hard origin/master
|
||||||
|
|
||||||
./update.sh
|
./update.sh
|
||||||
|
|
||||||
|
# force install requirements
|
||||||
|
python3 -m venv venv
|
||||||
|
source ./venv/bin/activate
|
||||||
|
|
||||||
|
python3 -m pip install --upgrade pip
|
||||||
|
python3 -m pip install -r ./dlas/requirements.txt
|
||||||
|
python3 -m pip install -r ./tortoise-tts/requirements.txt
|
||||||
|
python3 -m pip install -e ./tortoise-tts
|
||||||
|
python3 -m pip install -r ./requirements.txt
|
||||||
|
|
||||||
|
deactivate
|
12
update.bat
12
update.bat
|
@ -1,14 +1,2 @@
|
||||||
git pull
|
git pull
|
||||||
git submodule update --remote
|
git submodule update --remote
|
||||||
|
|
||||||
python -m venv venv
|
|
||||||
call .\venv\Scripts\activate.bat
|
|
||||||
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
python -m pip install -r .\dlas\requirements.txt
|
|
||||||
python -m pip install -r .\tortoise-tts\requirements.txt
|
|
||||||
python -m pip install -e .\tortoise-tts
|
|
||||||
python -m pip install -r .\requirements.txt
|
|
||||||
|
|
||||||
deactivate
|
|
||||||
pause
|
|
10
update.sh
10
update.sh
|
@ -2,14 +2,4 @@
|
||||||
git pull
|
git pull
|
||||||
git submodule update --remote
|
git submodule update --remote
|
||||||
|
|
||||||
python3 -m venv venv
|
|
||||||
source ./venv/bin/activate
|
|
||||||
|
|
||||||
python3 -m pip install --upgrade pip
|
|
||||||
python3 -m pip install -r ./dlas/requirements.txt
|
|
||||||
python3 -m pip install -r ./tortoise-tts/requirements.txt
|
|
||||||
python3 -m pip install -e ./tortoise-tts
|
|
||||||
python3 -m pip install -r ./requirements.txt
|
|
||||||
|
|
||||||
|
|
||||||
deactivate
|
deactivate
|
Loading…
Reference in New Issue
Block a user