huge success

mrq 2023-02-23 06:24:54 +00:00
parent aa96edde2f
commit 225dee22d4
9 changed files with 154 additions and 87 deletions


@@ -17,3 +17,18 @@ Please consult [the wiki](https://git.ecker.tech/mrq/ai-voice-cloning/wiki) for
## Bug Reporting
If you run into any problems, please refer to the [issues you may encounter](https://git.ecker.tech/mrq/ai-voice-cloning/wiki/Issues) wiki page first. Please don't hesitate to submit an issue.
+## Changelogs
+Below is a rather loose changelog, as I don't think I have a way to chronicle changes outside of commit messages:
+
+### `2023.02.22`
+* greatly reduced VRAM consumption through the use of [TimDettmers/bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+* cleaned up the section of code that handles parsing output from the training script
+* added a button to reconnect to the training script's output (it sometimes skips a line when updating, but it's better than nothing)
+* actually update submodules from the update script (somehow forgot to pass `--remote`)
+
+### `Before 2023.02.22`
+Refer to the commit logs.
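
For context on the VRAM claim: bitsandbytes swaps PyTorch's optimizers for 8-bit variants that keep optimizer state quantized, which is where most of the savings come from. A minimal sketch of the swap in a bare training step (a generic example, not this repo's actual integration, which goes through DLAS):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()

# Adam normally keeps two fp32 state tensors per parameter; the 8-bit
# variant quantizes that state, roughly quartering the optimizer's VRAM use.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 512).cuda()
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```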

dlas (submodule)

@@ -1 +1 @@
-Subproject commit 6c284ef8ec4c4769de3181d90ac96ff63581ef55
+Subproject commit 0ef8ab6872813d1021d4d75e82b63377d28f5a06


@@ -2,7 +2,7 @@ name: ${name}
model: extensibletrainer
scale: 1
gpu_ids: [0] # <-- unless you have multiple gpus, use this
-start_step: -1
+start_step: 0
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
fp16: ${float16} # might want to check this out
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.


@@ -9,5 +9,8 @@ python -m pip install -r .\dlas\requirements.txt
python -m pip install -r .\tortoise-tts\requirements.txt
python -m pip install -r .\requirements.txt
python -m pip install -e .\tortoise-tts\
+copy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
deactivate
pause


@@ -1,8 +1,8 @@
-import torch
+import argparse
import os
import sys
-import argparse

# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
# its smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
@@ -19,6 +19,17 @@ sys.path.insert(0, './dlas/')
# don't even really bother trying to get DLAS PIP'd
# without kludge, it'll have to be accessible as `codes` and not `dlas`
+import torch_intermediary
+# could just move this auto-toggle into the MITM script
+try:
+    import bitsandbytes as bnb
+    torch_intermediary.OVERRIDE_ADAM = True
+    torch_intermediary.OVERRIDE_ADAMW = True
+except Exception as e:
+    torch_intermediary.OVERRIDE_ADAM = False
+    torch_intermediary.OVERRIDE_ADAMW = False
+
+import torch

from codes import train as tr
from utils import util, options as option
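
The `torch_intermediary` module itself isn't shown in this diff. As a rough sketch of what such a man-in-the-middle shim could look like (only the `OVERRIDE_ADAM`/`OVERRIDE_ADAMW` flags are confirmed by the code above; everything else here is an assumption):

```python
# hypothetical torch_intermediary.py sketch, not the repo's actual code
import torch

OVERRIDE_ADAM = False
OVERRIDE_ADAMW = False

def AdamW(params, *args, **kwargs):
    # Route to bitsandbytes' 8-bit AdamW when the override is enabled,
    # falling back to the stock PyTorch optimizer otherwise.
    if OVERRIDE_ADAMW:
        import bitsandbytes as bnb
        return bnb.optim.AdamW8bit(params, *args, **kwargs)
    return torch.optim.AdamW(params, *args, **kwargs)
```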


@@ -17,6 +17,7 @@ import urllib.request
import signal
import gc
import subprocess
+import yaml

import tqdm
import torch
@@ -26,6 +27,7 @@ import gradio as gr
import gradio.utils

from datetime import datetime
+from datetime import timedelta

from tortoise.api import TextToSpeech, MODELS, get_model_path
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
@@ -42,7 +44,7 @@ tts_loading = False
webui = None
voicefixer = None
whisper_model = None
-training_process = None
+training_state = None

def generate(
@@ -434,8 +436,88 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
    return voice

+# superfluous, but it cleans up some things
+class TrainingState():
+    def __init__(self, config_path, buffer_size=8):
+        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
+
+        # parse config to get its iteration count
+        with open(config_path, 'r') as file:
+            self.config = yaml.safe_load(file)
+
+        self.it = 0
+        self.its = self.config['train']['niter']
+
+        self.checkpoint = 0
+        self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
+
+        self.buffer = []
+
+        self.open_state = False
+        self.training_started = False
+
+        self.info = {}
+        self.status = ""
+
+        self.it_rate = ""
+        self.it_time_start = 0
+        self.it_time_end = 0
+        self.eta = "?"
+
+        print("Spawning process: ", " ".join(self.cmd))
+        self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
+
+    def parse(self, line, verbose=False, buffer_size=8, progress=None):
+        self.buffer.append(f'{line}')
+
+        # rip out iteration info
+        if not self.training_started:
+            if line.find('Start training from epoch') >= 0:
+                self.it_time_start = time.time()
+                self.training_started = True  # could just leverage the above variable, but this is python, and there's no point in these aggressive micro-optimizations
+
+            match = re.findall(r'iter: ([\d,]+)', line)
+            if match and len(match) > 0:
+                self.it = int(match[0].replace(",", ""))
+        elif progress is not None:
+            if line.find(' 0%|') == 0:
+                self.open_state = True
+            elif line.find('100%|') == 0 and self.open_state:
+                self.open_state = False
+                self.it = self.it + 1
+
+                self.it_time_end = time.time()
+                self.it_time_delta = self.it_time_end - self.it_time_start
+                self.it_time_start = time.time()
+                self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'  # I doubt anyone will have it/s rates, but it's here
+                self.eta = (self.its - self.it) * self.it_time_delta
+                self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
+
+                progress(self.it / float(self.its), f'[{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.it_rate} Training... {self.status}')
+
+        if line.find('INFO: [epoch:') >= 0:
+            # easily rip out our stats...
+            match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
+            if match and len(match) > 0:
+                for k, v in match:
+                    self.info[k] = float(v)
+
+            # ...and report our loss rate
+            # it would be nice for losses to be shown at every step
+            if 'loss_gpt_total' in self.info:
+                # self.info['step'] reports steps, not iterations, so don't bother with the reported step count; the iteration count won't get picked up by the regex anyway
+                self.status = f"Total loss at iteration {self.it}: {self.info['loss_gpt_total']}"
+        elif line.find('Saving models and training states') >= 0:
+            self.checkpoint = self.checkpoint + 1
+            progress(self.checkpoint / float(self.checkpoints), f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...')
+
+        if verbose or not self.training_started:
+            return "".join(self.buffer[-buffer_size:])
def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
-    global training_process
+    global training_state
+    if training_state and training_state.process:
+        return "Training already in progress"

    # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
    torch.multiprocessing.freeze_support()
@@ -444,90 +526,38 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
    unload_whisper()
    unload_voicefixer()

-    cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
-    print("Spawning process: ", " ".join(cmd))
-    training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
-
-    # parse config to get its iteration
-    import yaml
-    with open(config_path, 'r') as file:
-        config = yaml.safe_load(file)
-
-    it = 0
-    its = config['train']['niter']
-
-    checkpoint = 0
-    checkpoints = its / config['logger']['save_checkpoint_freq']
-
-    buffer_size = 8
-    open_state = False
-    training_started = False
-
-    yield " ".join(cmd)
-
-    info = {}
-    buffer = []
-    infos = []
-    yields = True
-    status = ""
-
-    it_rate = ""
-    it_time_start = 0
-    it_time_end = 0
-
-    for line in iter(training_process.stdout.readline, ""):
-        buffer.append(f'{line}')
-
-        # rip out iteration info
-        if not training_started:
-            if line.find('Start training from epoch') >= 0:
-                training_started = True
-
-            match = re.findall(r'iter: ([\d,]+)', line)
-            if match and len(match) > 0:
-                it = int(match[0].replace(",", ""))
-        elif progress is not None:
-            if line.find(' 0%|') == 0:
-                open_state = True
-            elif line.find('100%|') == 0 and open_state:
-                open_state = False
-                it = it + 1
-
-                it_time_end = time.time()
-                it_time_delta = it_time_end - it_time_start
-                it_time_start = time.time()
-                it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
-
-                progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}')
-
-        if line.find('INFO: [epoch:') >= 0:
-            # easily rip out our stats...
-            match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
-            if match and len(match) > 0:
-                for k, v in match:
-                    info[k] = float(v)
-
-            # ...and returns our loss rate
-            # it would be nice for losses to be shown at every step
-            if 'loss_gpt_total' in info:
-                status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}"
-        elif line.find('Saving models and training states') >= 0:
-            checkpoint = checkpoint + 1
-            progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
-
+    training_state = TrainingState(config_path=config_path, buffer_size=buffer_size)
+
+    for line in iter(training_state.process.stdout.readline, ""):
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
-        if verbose or not training_started:
-            yield "".join(buffer[-buffer_size:])
+        res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+        if res:
+            yield res

-    training_process.stdout.close()
-    return_code = training_process.wait()
-    training_process = None
+    training_state.process.stdout.close()
+    return_code = training_state.process.wait()
+    output = "".join(training_state.buffer[-buffer_size:])
+    training_state = None

    #if return_code:
    #    raise subprocess.CalledProcessError(return_code, cmd)

-    return "".join(buffer[-buffer_size:])
+    return output

+def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
+    global training_state
+    if not training_state or not training_state.process:
+        return "Training not in progress"
+
+    for line in iter(training_state.process.stdout.readline, ""):
+        res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+        if res:
+            yield res
+
+    output = "".join(training_state.buffer[-buffer_size:])
+    return output
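
The reconnect trick works because the `Popen` handle lives in the module-level `training_state`, so a fresh UI callback can simply resume draining the same stdout pipe after the first one drops. A minimal sketch of that pattern (a standalone example, not the repo's code):

```python
import subprocess

# one long-lived process handle, shared by whichever reader is attached
process = subprocess.Popen(
    ["python", "-u", "-c", "import time\nfor i in range(5): print(i); time.sleep(1)"],
    stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True,
)

def attach(process):
    # any caller can (re)attach and drain whatever output remains
    for line in iter(process.stdout.readline, ""):
        yield line

for line in attach(process):  # a dropped reader can just call attach() again
    print(line, end="")
```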
def stop_training():
    global training_process


@@ -410,6 +410,7 @@ def setup_gradio():
            refresh_configs = gr.Button(value="Refresh Configurations")
            start_training_button = gr.Button(value="Train")
            stop_training_button = gr.Button(value="Stop")
+            reconnect_training_button = gr.Button(value="Reconnect")
        with gr.Column():
            training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
            verbose_training = gr.Checkbox(label="Verbose Console Output")
@@ -614,6 +615,13 @@ def setup_gradio():
            inputs=None,
            outputs=training_output #console_output
        )
+        reconnect_training_button.click(reconnect_training,
+            inputs=[
+                verbose_training,
+                training_buffer_size,
+            ],
+            outputs=training_output #console_output
+        )
        prepare_dataset_button.click(
            prepare_dataset_proxy,
            inputs=dataset_settings,
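
Worth noting for the wiring above: Gradio treats a generator callback as a stream, so each `yield` from `run_training` or `reconnect_training` repaints the `training_output` textbox. A toy example of the mechanism (hypothetical names, unrelated to this repo):

```python
import time
import gradio as gr

def stream_logs():
    buffer = []
    for i in range(3):
        buffer.append(f"line {i}")
        time.sleep(1)
        yield "\n".join(buffer)  # each yield replaces the textbox contents

with gr.Blocks() as demo:
    output = gr.TextArea(label="Console Output", interactive=False)
    gr.Button("Run").click(stream_logs, inputs=None, outputs=output)

demo.launch()
```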


@@ -1,5 +1,5 @@
git pull
-git submodule update
+git submodule update --remote
python -m venv venv
call .\venv\Scripts\activate.bat


@@ -1,6 +1,6 @@
#!/bin/bash
git pull
-git submodule update
+git submodule update --remote
python3 -m venv venv
source ./venv/bin/activate