huge success
This commit is contained in:
parent
aa96edde2f
commit
225dee22d4
17
README.md
17
README.md
|
@ -16,4 +16,19 @@ Please consult [the wiki](https://git.ecker.tech/mrq/ai-voice-cloning/wiki) for
|
||||||
|
|
||||||
## Bug Reporting
|
## Bug Reporting
|
||||||
|
|
||||||
If you run into any problems, please refer to the [issues you may encounter](https://git.ecker.tech/mrq/ai-voice-cloning/wiki/Issues) wiki page first. Please don't hesitate to submit an issue.
|
If you run into any problems, please refer to the [issues you may encounter](https://git.ecker.tech/mrq/ai-voice-cloning/wiki/Issues) wiki page first. Please don't hesitate to submit an issue.
|
||||||
|
|
||||||
|
## Changelogs
|
||||||
|
|
||||||
|
Below will be a rather-loose changelogss, as I don't think I have a way to chronicle them outside of commit messages:
|
||||||
|
|
||||||
|
### `2023.02.22`
|
||||||
|
|
||||||
|
* greatly reduced VRAM consumption through the use of [TimDettmers/bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
||||||
|
* cleaned up section of code that handled parsing output from training script
|
||||||
|
* added button to reconnect to the training script's output (sometimes skips a line to update, but it's better than nothing)
|
||||||
|
* actually update submodules from the update script (somehow forgot to pass `--remote`)
|
||||||
|
|
||||||
|
### `Before 2023.02.22`
|
||||||
|
|
||||||
|
Refer to commit logs.
|
2
dlas
2
dlas
|
@ -1 +1 @@
|
||||||
Subproject commit 6c284ef8ec4c4769de3181d90ac96ff63581ef55
|
Subproject commit 0ef8ab6872813d1021d4d75e82b63377d28f5a06
|
|
@ -2,7 +2,7 @@ name: ${name}
|
||||||
model: extensibletrainer
|
model: extensibletrainer
|
||||||
scale: 1
|
scale: 1
|
||||||
gpu_ids: [0] # <-- unless you have multiple gpus, use this
|
gpu_ids: [0] # <-- unless you have multiple gpus, use this
|
||||||
start_step: -1
|
start_step: 0
|
||||||
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
|
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
|
||||||
fp16: ${float16} # might want to check this out
|
fp16: ${float16} # might want to check this out
|
||||||
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
|
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
|
||||||
|
|
|
@ -9,5 +9,8 @@ python -m pip install -r .\dlas\requirements.txt
|
||||||
python -m pip install -r .\tortoise-tts\requirements.txt
|
python -m pip install -r .\tortoise-tts\requirements.txt
|
||||||
python -m pip install -r .\requirements.txt
|
python -m pip install -r .\requirements.txt
|
||||||
python -m pip install -e .\tortoise-tts\
|
python -m pip install -e .\tortoise-tts\
|
||||||
|
|
||||||
|
copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
|
||||||
|
|
||||||
deactivate
|
deactivate
|
||||||
pause
|
pause
|
17
src/train.py
17
src/train.py
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||||
|
@ -19,6 +19,17 @@ sys.path.insert(0, './dlas/')
|
||||||
# don't even really bother trying to get DLAS PIP'd
|
# don't even really bother trying to get DLAS PIP'd
|
||||||
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
||||||
|
|
||||||
|
import torch_intermediary
|
||||||
|
# could just move this auto-toggle into the MITM script
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
torch_intermediary.OVERRIDE_ADAM = True
|
||||||
|
torch_intermediary.OVERRIDE_ADAMW = True
|
||||||
|
except Exception as e:
|
||||||
|
torch_intermediary.OVERRIDE_ADAM = False
|
||||||
|
torch_intermediary.OVERRIDE_ADAMW = False
|
||||||
|
|
||||||
|
import torch
|
||||||
from codes import train as tr
|
from codes import train as tr
|
||||||
from utils import util, options as option
|
from utils import util, options as option
|
||||||
|
|
||||||
|
|
188
src/utils.py
188
src/utils.py
|
@ -17,6 +17,7 @@ import urllib.request
|
||||||
import signal
|
import signal
|
||||||
import gc
|
import gc
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import yaml
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
@ -26,6 +27,7 @@ import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech, MODELS, get_model_path
|
from tortoise.api import TextToSpeech, MODELS, get_model_path
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
||||||
|
@ -42,7 +44,7 @@ tts_loading = False
|
||||||
webui = None
|
webui = None
|
||||||
voicefixer = None
|
voicefixer = None
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
training_process = None
|
training_state = None
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
@ -434,8 +436,88 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
|
|
||||||
return voice
|
return voice
|
||||||
|
|
||||||
|
# superfluous, but it cleans up some things
|
||||||
|
class TrainingState():
|
||||||
|
def __init__(self, config_path, buffer_size=8):
|
||||||
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
||||||
|
|
||||||
|
# parse config to get its iteration
|
||||||
|
with open(config_path, 'r') as file:
|
||||||
|
self.config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
self.it = 0
|
||||||
|
self.its = self.config['train']['niter']
|
||||||
|
|
||||||
|
self.checkpoint = 0
|
||||||
|
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||||
|
|
||||||
|
self.buffer = []
|
||||||
|
|
||||||
|
self.open_state = False
|
||||||
|
self.training_started = False
|
||||||
|
|
||||||
|
self.info = {}
|
||||||
|
self.status = ""
|
||||||
|
|
||||||
|
self.it_rate = ""
|
||||||
|
self.it_time_start = 0
|
||||||
|
self.it_time_end = 0
|
||||||
|
self.eta = "?"
|
||||||
|
|
||||||
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
|
def parse(self, line, verbose=False, buffer_size=8, progress=None):
|
||||||
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
|
# rip out iteration info
|
||||||
|
if not self.training_started:
|
||||||
|
if line.find('Start training from epoch') >= 0:
|
||||||
|
self.it_time_start = time.time()
|
||||||
|
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||||
|
|
||||||
|
match = re.findall(r'iter: ([\d,]+)', line)
|
||||||
|
if match and len(match) > 0:
|
||||||
|
self.it = int(match[0].replace(",", ""))
|
||||||
|
elif progress is not None:
|
||||||
|
if line.find(' 0%|') == 0:
|
||||||
|
self.open_state = True
|
||||||
|
elif line.find('100%|') == 0 and self.open_state:
|
||||||
|
self.open_state = False
|
||||||
|
self.it = self.it + 1
|
||||||
|
|
||||||
|
self.it_time_end = time.time()
|
||||||
|
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||||
|
self.it_time_start = time.time()
|
||||||
|
self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
|
||||||
|
self.eta = (self.its - self.it) * self.it_time_delta
|
||||||
|
self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
|
||||||
|
|
||||||
|
progress(self.it / float(self.its), f'[{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.it_rate} Training... {self.status}')
|
||||||
|
|
||||||
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
|
# easily rip out our stats...
|
||||||
|
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
||||||
|
if match and len(match) > 0:
|
||||||
|
for k, v in match:
|
||||||
|
self.info[k] = float(v)
|
||||||
|
|
||||||
|
# ...and returns our loss rate
|
||||||
|
# it would be nice for losses to be shown at every step
|
||||||
|
if 'loss_gpt_total' in self.info:
|
||||||
|
# self.info['step'] returns the steps, not iterations, so we won't even bother ripping the reported step count, as iteration count won't get ripped from the regex
|
||||||
|
self.status = f"Total loss at iteration {self.it}: {self.info['loss_gpt_total']}"
|
||||||
|
elif line.find('Saving models and training states') >= 0:
|
||||||
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
progress(self.checkpoint / float(self.checkpoints), f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...')
|
||||||
|
|
||||||
|
if verbose or not self.training_started:
|
||||||
|
return "".join(self.buffer[-buffer_size:])
|
||||||
|
|
||||||
def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_process
|
global training_state
|
||||||
|
if training_state and training_state.process:
|
||||||
|
return "Training already in progress"
|
||||||
|
|
||||||
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
||||||
torch.multiprocessing.freeze_support()
|
torch.multiprocessing.freeze_support()
|
||||||
|
@ -444,90 +526,38 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
unload_voicefixer()
|
unload_voicefixer()
|
||||||
|
|
||||||
cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
training_state = TrainingState(config_path=config_path, buffer_size=buffer_size)
|
||||||
print("Spawning process: ", " ".join(cmd))
|
|
||||||
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
|
||||||
|
|
||||||
# parse config to get its iteration
|
|
||||||
import yaml
|
|
||||||
with open(config_path, 'r') as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
|
|
||||||
it = 0
|
|
||||||
its = config['train']['niter']
|
|
||||||
|
|
||||||
checkpoint = 0
|
|
||||||
checkpoints = its / config['logger']['save_checkpoint_freq']
|
|
||||||
|
|
||||||
buffer_size = 8
|
|
||||||
open_state = False
|
|
||||||
training_started = False
|
|
||||||
|
|
||||||
yield " ".join(cmd)
|
|
||||||
|
|
||||||
info = {}
|
|
||||||
buffer = []
|
|
||||||
infos = []
|
|
||||||
yields = True
|
|
||||||
status = ""
|
|
||||||
|
|
||||||
it_rate = ""
|
|
||||||
it_time_start = 0
|
|
||||||
it_time_end = 0
|
|
||||||
|
|
||||||
for line in iter(training_process.stdout.readline, ""):
|
|
||||||
buffer.append(f'{line}')
|
|
||||||
|
|
||||||
# rip out iteration info
|
|
||||||
if not training_started:
|
|
||||||
if line.find('Start training from epoch') >= 0:
|
|
||||||
training_started = True
|
|
||||||
|
|
||||||
match = re.findall(r'iter: ([\d,]+)', line)
|
|
||||||
if match and len(match) > 0:
|
|
||||||
it = int(match[0].replace(",", ""))
|
|
||||||
elif progress is not None:
|
|
||||||
if line.find(' 0%|') == 0:
|
|
||||||
open_state = True
|
|
||||||
elif line.find('100%|') == 0 and open_state:
|
|
||||||
open_state = False
|
|
||||||
it = it + 1
|
|
||||||
|
|
||||||
it_time_end = time.time()
|
|
||||||
it_time_delta = it_time_end-it_time_start
|
|
||||||
it_time_start = time.time()
|
|
||||||
it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
|
|
||||||
|
|
||||||
progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}')
|
|
||||||
|
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
|
||||||
# easily rip out our stats...
|
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
|
||||||
if match and len(match) > 0:
|
|
||||||
for k, v in match:
|
|
||||||
info[k] = float(v)
|
|
||||||
|
|
||||||
# ...and returns our loss rate
|
|
||||||
# it would be nice for losses to be shown at every step
|
|
||||||
if 'loss_gpt_total' in info:
|
|
||||||
status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}"
|
|
||||||
elif line.find('Saving models and training states') >= 0:
|
|
||||||
checkpoint = checkpoint + 1
|
|
||||||
progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
|
|
||||||
|
|
||||||
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||||
|
|
||||||
|
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
|
||||||
|
if res:
|
||||||
|
yield res
|
||||||
|
|
||||||
if verbose or not training_started:
|
training_state.process.stdout.close()
|
||||||
yield "".join(buffer[-buffer_size:])
|
return_code = training_state.process.wait()
|
||||||
|
output = "".join(training_state.buffer[-buffer_size:])
|
||||||
training_process.stdout.close()
|
training_state = None
|
||||||
return_code = training_process.wait()
|
|
||||||
training_process = None
|
|
||||||
|
|
||||||
#if return_code:
|
#if return_code:
|
||||||
# raise subprocess.CalledProcessError(return_code, cmd)
|
# raise subprocess.CalledProcessError(return_code, cmd)
|
||||||
|
|
||||||
return "".join(buffer[-buffer_size:])
|
return output
|
||||||
|
|
||||||
|
def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
||||||
|
global training_state
|
||||||
|
if not training_state or not training_state.process:
|
||||||
|
return "Training not in progress"
|
||||||
|
|
||||||
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
|
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
|
||||||
|
if res:
|
||||||
|
yield res
|
||||||
|
|
||||||
|
output = "".join(training_state.buffer[-buffer_size:])
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def stop_training():
|
def stop_training():
|
||||||
global training_process
|
global training_process
|
||||||
|
|
|
@ -410,6 +410,7 @@ def setup_gradio():
|
||||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||||
start_training_button = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
stop_training_button = gr.Button(value="Stop")
|
stop_training_button = gr.Button(value="Stop")
|
||||||
|
reconnect_training_button = gr.Button(value="Reconnect")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
||||||
|
@ -614,6 +615,13 @@ def setup_gradio():
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=training_output #console_output
|
outputs=training_output #console_output
|
||||||
)
|
)
|
||||||
|
reconnect_training_button.click(reconnect_training,
|
||||||
|
inputs=[
|
||||||
|
verbose_training,
|
||||||
|
training_buffer_size,
|
||||||
|
],
|
||||||
|
outputs=training_output #console_output
|
||||||
|
)
|
||||||
prepare_dataset_button.click(
|
prepare_dataset_button.click(
|
||||||
prepare_dataset_proxy,
|
prepare_dataset_proxy,
|
||||||
inputs=dataset_settings,
|
inputs=dataset_settings,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
git pull
|
git pull
|
||||||
git submodule update
|
git submodule update --remote
|
||||||
|
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
call .\venv\Scripts\activate.bat
|
call .\venv\Scripts\activate.bat
|
||||||
|
|
Loading…
Reference in New Issue
Block a user