added option to set worker size in training config generator (because the default is overkill), for whisper transcriptions, load a specialized language model if it exists (for now, only english), output transcription to web UI when done transcribing
This commit is contained in:
parent
37cab14272
commit
3e220ed306
|
@ -11,7 +11,7 @@ use_tb_logger: true
|
||||||
datasets:
|
datasets:
|
||||||
train:
|
train:
|
||||||
name: ${dataset_name}
|
name: ${dataset_name}
|
||||||
n_workers: 8
|
n_workers: ${workers}
|
||||||
batch_size: ${batch_size}
|
batch_size: ${batch_size}
|
||||||
mode: paired_voice_audio
|
mode: paired_voice_audio
|
||||||
path: ${dataset_path}
|
path: ${dataset_path}
|
||||||
|
|
94
src/utils.py
94
src/utils.py
|
@ -37,6 +37,8 @@ from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name
|
from tortoise.utils.device import get_device_name, set_device_name
|
||||||
|
|
||||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
|
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||||
|
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
|
@ -663,6 +665,7 @@ class TrainingState():
|
||||||
# rip out iteration info
|
# rip out iteration info
|
||||||
if not self.training_started:
|
if not self.training_started:
|
||||||
if line.find('Start training from epoch') >= 0:
|
if line.find('Start training from epoch') >= 0:
|
||||||
|
self.it_time_start = time.time()
|
||||||
self.epoch_time_start = time.time()
|
self.epoch_time_start = time.time()
|
||||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||||
should_return = True
|
should_return = True
|
||||||
|
@ -703,6 +706,7 @@ class TrainingState():
|
||||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||||
self.it_time_start = time.time()
|
self.it_time_start = time.time()
|
||||||
self.it_taken = self.it_taken + 1
|
self.it_taken = self.it_taken + 1
|
||||||
|
if self.it_time_delta:
|
||||||
try:
|
try:
|
||||||
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
||||||
self.it_rate = rate
|
self.it_rate = rate
|
||||||
|
@ -733,9 +737,23 @@ class TrainingState():
|
||||||
metric_loss = []
|
metric_loss = []
|
||||||
if len(self.losses) > 0:
|
if len(self.losses) > 0:
|
||||||
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
||||||
|
|
||||||
|
if len(self.losses) >= 2:
|
||||||
|
delta_loss = self.losses[-2]["value"] - self.losses[-1]["value"]
|
||||||
|
delta_step = self.losses[-2]["step"] - self.losses[-1]["step"]
|
||||||
|
|
||||||
|
inst_deriv = delta_loss / delta_step
|
||||||
|
est_loss = delta_loss + (self.its - self.it) * inst_deriv
|
||||||
|
metric_loss.append(f'Est. Final Loss: {"{:3f}".format(est_loss)}')
|
||||||
|
|
||||||
|
print(delta_loss, delta_step, inst_deriv, est_loss)
|
||||||
|
|
||||||
|
|
||||||
metric_loss = ", ".join(metric_loss)
|
metric_loss = ", ".join(metric_loss)
|
||||||
|
|
||||||
message = f'[{metric_step}] [{metric_rate}] [{metric_loss}] [ETA: {eta_hhmmss}]'
|
|
||||||
|
|
||||||
|
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
|
||||||
|
|
||||||
if lapsed:
|
if lapsed:
|
||||||
self.epoch = self.epoch + 1
|
self.epoch = self.epoch + 1
|
||||||
|
@ -764,6 +782,13 @@ class TrainingState():
|
||||||
self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
|
self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
|
||||||
|
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
|
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
||||||
|
if ': nan' in line:
|
||||||
|
should_return = True
|
||||||
|
|
||||||
|
print("! NAN DETECTED !")
|
||||||
|
self.buffer.append("! NAN DETECTED !")
|
||||||
|
|
||||||
# easily rip out our stats...
|
# easily rip out our stats...
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
|
@ -824,14 +849,14 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
||||||
if result:
|
if result:
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
if progress is not None and message:
|
||||||
|
progress(percent, message)
|
||||||
|
|
||||||
if training_state:
|
if training_state:
|
||||||
training_state.process.stdout.close()
|
training_state.process.stdout.close()
|
||||||
return_code = training_state.process.wait()
|
return_code = training_state.process.wait()
|
||||||
training_state = None
|
training_state = None
|
||||||
|
|
||||||
#if return_code:
|
|
||||||
# raise subprocess.CalledProcessError(return_code, cmd)
|
|
||||||
|
|
||||||
def get_training_losses():
|
def get_training_losses():
|
||||||
global training_state
|
global training_state
|
||||||
if not training_state or not training_state.losses:
|
if not training_state or not training_state.losses:
|
||||||
|
@ -866,6 +891,9 @@ def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
|
||||||
if result:
|
if result:
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
if progress is not None and message:
|
||||||
|
progress(percent, message)
|
||||||
|
|
||||||
def stop_training():
|
def stop_training():
|
||||||
global training_state
|
global training_state
|
||||||
if training_state is None:
|
if training_state is None:
|
||||||
|
@ -910,10 +938,10 @@ def convert_to_halfp():
|
||||||
def whisper_transcribe( file, language=None ):
|
def whisper_transcribe( file, language=None ):
|
||||||
# shouldn't happen, but it's for safety
|
# shouldn't happen, but it's for safety
|
||||||
if not whisper_model:
|
if not whisper_model:
|
||||||
load_whisper_model(language=language if language else b'en')
|
load_whisper_model(language=language)
|
||||||
|
|
||||||
if not args.whisper_cpp:
|
if not args.whisper_cpp:
|
||||||
return whisper_model.transcribe(file, language=language if language else "English")
|
return whisper_model.transcribe(file, language=language)
|
||||||
|
|
||||||
res = whisper_model.transcribe(file)
|
res = whisper_model.transcribe(file)
|
||||||
segments = whisper_model.extract_text_and_timestamps( res )
|
segments = whisper_model.extract_text_and_timestamps( res )
|
||||||
|
@ -945,11 +973,8 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
|
||||||
transcription = []
|
transcription = []
|
||||||
|
|
||||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
print(f"Transcribing file: {file}")
|
result = whisper_transcribe(file, language=language)
|
||||||
|
|
||||||
result = whisper_transcribe(file, language=language) # whisper_model.transcribe(file, language=language if language else "English")
|
|
||||||
results[os.path.basename(file)] = result
|
results[os.path.basename(file)] = result
|
||||||
|
|
||||||
print(f"Transcribed file: {file}, {len(result['segments'])} found.")
|
print(f"Transcribed file: {file}, {len(result['segments'])} found.")
|
||||||
|
|
||||||
waveform, sampling_rate = torchaudio.load(file)
|
waveform, sampling_rate = torchaudio.load(file)
|
||||||
|
@ -988,7 +1013,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||||
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
||||||
return [int(iterations * d) for d in schedule]
|
return [int(iterations * d) for d in schedule]
|
||||||
|
|
||||||
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
|
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -1065,7 +1090,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, source_model=None ):
|
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
|
||||||
if not source_model:
|
if not source_model:
|
||||||
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
|
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
|
||||||
|
|
||||||
|
@ -1090,6 +1115,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
|
||||||
|
|
||||||
'float16': 'true' if half_p else 'false',
|
'float16': 'true' if half_p else 'false',
|
||||||
'bitsandbytes': 'true' if bnb else 'false',
|
'bitsandbytes': 'true' if bnb else 'false',
|
||||||
|
|
||||||
|
'workers': workers if workers else 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
if resume_path:
|
if resume_path:
|
||||||
|
@ -1581,9 +1608,9 @@ def unload_tts():
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
if tts:
|
if tts:
|
||||||
print("Unloading TTS")
|
|
||||||
del tts
|
del tts
|
||||||
tts = None
|
tts = None
|
||||||
|
print("Unloaded TTS")
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
def reload_tts( model=None ):
|
def reload_tts( model=None ):
|
||||||
|
@ -1656,55 +1683,44 @@ def unload_voicefixer():
|
||||||
global voicefixer
|
global voicefixer
|
||||||
|
|
||||||
if voicefixer:
|
if voicefixer:
|
||||||
print("Unloading Voicefixer")
|
|
||||||
del voicefixer
|
del voicefixer
|
||||||
voicefixer = None
|
voicefixer = None
|
||||||
print("Unloaded Voicefixer")
|
print("Unloaded Voicefixer")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
def load_whisper_model(name=None, progress=None, language=b'en'):
|
def load_whisper_model(language=None, model_name=None, progress=None):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
|
||||||
if not name:
|
if not model_name:
|
||||||
name = args.whisper_model
|
model_name = args.whisper_model
|
||||||
else:
|
else:
|
||||||
args.whisper_model = name
|
args.whisper_model = model_name
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
|
|
||||||
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
|
if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
|
||||||
|
model_name = f'{model_name}.{language}'
|
||||||
|
print(f"Loading specialized model for language: {language}")
|
||||||
|
|
||||||
|
notify_progress(f"Loading Whisper model: {model_name}", progress)
|
||||||
if args.whisper_cpp:
|
if args.whisper_cpp:
|
||||||
from whispercpp import Whisper
|
from whispercpp import Whisper
|
||||||
whisper_model = Whisper(name, models_dir='./models/', language=language.encode('ascii'))
|
if not language:
|
||||||
|
language = 'auto'
|
||||||
|
|
||||||
|
whisper_model = Whisper(model_name, models_dir='./models/', language=language.encode('ascii'))
|
||||||
else:
|
else:
|
||||||
import whisper
|
import whisper
|
||||||
whisper_model = whisper.load_model(args.whisper_model)
|
whisper_model = whisper.load_model(model_name)
|
||||||
|
|
||||||
print("Loaded Whisper model")
|
print("Loaded Whisper model")
|
||||||
|
|
||||||
def unload_whisper():
|
def unload_whisper():
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
|
||||||
if whisper_model:
|
if whisper_model:
|
||||||
print("Unloading Whisper")
|
|
||||||
del whisper_model
|
del whisper_model
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
print("Unloaded Whisper")
|
print("Unloaded Whisper")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
"""
|
|
||||||
def update_whisper_model(name, progress=None):
|
|
||||||
if not name:
|
|
||||||
return
|
|
||||||
|
|
||||||
args.whisper_model = name
|
|
||||||
save_args_settings()
|
|
||||||
|
|
||||||
global whisper_model
|
|
||||||
if whisper_model:
|
|
||||||
unload_whisper()
|
|
||||||
load_whisper_model(name)
|
|
||||||
else:
|
|
||||||
args.whisper_model = name
|
|
||||||
save_args_settings()
|
|
||||||
"""
|
|
19
src/webui.py
19
src/webui.py
|
@ -268,6 +268,8 @@ def import_training_settings_proxy( voice ):
|
||||||
if "ext" in config and "bitsandbytes" in config["ext"]:
|
if "ext" in config and "bitsandbytes" in config["ext"]:
|
||||||
bnb = config["ext"]["bitsandbytes"]
|
bnb = config["ext"]["bitsandbytes"]
|
||||||
|
|
||||||
|
workers = config['datasets']['train']['n_workers']
|
||||||
|
|
||||||
messages = "\n".join(messages)
|
messages = "\n".join(messages)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -282,12 +284,13 @@ def import_training_settings_proxy( voice ):
|
||||||
resume_path,
|
resume_path,
|
||||||
half_p,
|
half_p,
|
||||||
bnb,
|
bnb,
|
||||||
|
workers,
|
||||||
source_model,
|
source_model,
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
|
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -330,6 +333,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
||||||
resume_path=resume_path,
|
resume_path=resume_path,
|
||||||
half_p=half_p,
|
half_p=half_p,
|
||||||
bnb=bnb,
|
bnb=bnb,
|
||||||
|
workers=workers,
|
||||||
source_model=source_model,
|
source_model=source_model,
|
||||||
))
|
))
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
@ -466,7 +470,7 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
dataset_settings = [
|
dataset_settings = [
|
||||||
gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ),
|
gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ),
|
||||||
gr.Textbox(label="Language", placeholder="English")
|
gr.Textbox(label="Language", value="en")
|
||||||
]
|
]
|
||||||
prepare_dataset_button = gr.Button(value="Prepare")
|
prepare_dataset_button = gr.Button(value="Prepare")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -499,11 +503,16 @@ def setup_gradio():
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
||||||
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
||||||
|
|
||||||
|
training_workers = gr.Number(label="Worker Processes", value=2, precision=0)
|
||||||
|
|
||||||
source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
|
source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
|
||||||
dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
||||||
training_settings = training_settings + [ training_halfp, training_bnb, source_model, dataset_list_dropdown ]
|
training_settings = training_settings + [ training_halfp, training_bnb, training_workers, source_model, dataset_list_dropdown ]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
||||||
|
@ -572,7 +581,7 @@ def setup_gradio():
|
||||||
|
|
||||||
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
||||||
|
|
||||||
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
|
whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
|
||||||
use_whisper_cpp = gr.Checkbox(label="Use Whisper.cpp", value=args.whisper_cpp)
|
use_whisper_cpp = gr.Checkbox(label="Use Whisper.cpp", value=args.whisper_cpp)
|
||||||
|
|
||||||
exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, use_whisper_cpp, training_halfp, training_bnb ]
|
exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, use_whisper_cpp, training_halfp, training_bnb ]
|
||||||
|
@ -797,7 +806,7 @@ def setup_gradio():
|
||||||
)
|
)
|
||||||
import_dataset_button.click(import_training_settings_proxy,
|
import_dataset_button.click(import_training_settings_proxy,
|
||||||
inputs=dataset_list_dropdown,
|
inputs=dataset_list_dropdown,
|
||||||
outputs=training_settings[:11] + [save_yaml_output] #console_output
|
outputs=training_settings[:13] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
save_yaml_button.click(save_training_settings_proxy,
|
save_yaml_button.click(save_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user