From 3e220ed306c619868fb4195afffd622a013d771d Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 5 Mar 2023 05:17:19 +0000 Subject: [PATCH] added option to set worker size in training config generator (because the default is overkill), for whisper transcriptions, load a specialized language model if it exists (for now, only english), output transcription to web UI when done transcribing --- models/.template.yaml | 2 +- src/utils.py | 106 ++++++++++++++++++++++++------------------ src/webui.py | 23 ++++++--- 3 files changed, 78 insertions(+), 53 deletions(-) diff --git a/models/.template.yaml b/models/.template.yaml index 825ac0c..77e5113 100755 --- a/models/.template.yaml +++ b/models/.template.yaml @@ -11,7 +11,7 @@ use_tb_logger: true datasets: train: name: ${dataset_name} - n_workers: 8 + n_workers: ${workers} batch_size: ${batch_size} mode: paired_voice_audio path: ${dataset_path} diff --git a/src/utils.py b/src/utils.py index 1f6a707..89dd1d1 100755 --- a/src/utils.py +++ b/src/utils.py @@ -37,6 +37,8 @@ from tortoise.utils.text import split_and_recombine_text from tortoise.utils.device import get_device_name, set_device_name MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" +WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] +WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] args = None tts = None @@ -663,6 +665,7 @@ class TrainingState(): # rip out iteration info if not self.training_started: if line.find('Start training from epoch') >= 0: + self.it_time_start = time.time() self.epoch_time_start = time.time() self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations should_return = True @@ -703,11 +706,12 @@ class TrainingState(): self.it_time_delta = self.it_time_end-self.it_time_start self.it_time_start = time.time() self.it_taken = self.it_taken + 1 - try: - rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s' - self.it_rate = rate - except Exception as e: - pass + if self.it_time_delta: + try: + rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s' + self.it_rate = rate + except Exception as e: + pass metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"] metric_step = ", ".join(metric_step) @@ -733,9 +737,23 @@ class TrainingState(): metric_loss = [] if len(self.losses) > 0: metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}') + + if len(self.losses) >= 2: + delta_loss = self.losses[-2]["value"] - self.losses[-1]["value"] + delta_step = self.losses[-2]["step"] - self.losses[-1]["step"] + + inst_deriv = delta_loss / delta_step + est_loss = delta_loss + (self.its - self.it) * inst_deriv + metric_loss.append(f'Est. Final Loss: {"{:3f}".format(est_loss)}') + + print(delta_loss, delta_step, inst_deriv, est_loss) + + metric_loss = ", ".join(metric_loss) - message = f'[{metric_step}] [{metric_rate}] [{metric_loss}] [ETA: {eta_hhmmss}]' + + + message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]' if lapsed: self.epoch = self.epoch + 1 @@ -764,6 +782,13 @@ class TrainingState(): self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}') if line.find('INFO: [epoch:') >= 0: + # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point + if ': nan' in line: + should_return = True + + print("! NAN DETECTED !") + self.buffer.append("! NAN DETECTED !") + # easily rip out our stats... match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) if match and len(match) > 0: @@ -824,13 +849,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro if result: yield result + if progress is not None and message: + progress(percent, message) + if training_state: training_state.process.stdout.close() return_code = training_state.process.wait() training_state = None - - #if return_code: - # raise subprocess.CalledProcessError(return_code, cmd) def get_training_losses(): global training_state @@ -866,6 +891,9 @@ def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)): if result: yield result + if progress is not None and message: + progress(percent, message) + def stop_training(): global training_state if training_state is None: @@ -910,10 +938,10 @@ def convert_to_halfp(): def whisper_transcribe( file, language=None ): # shouldn't happen, but it's for safety if not whisper_model: - load_whisper_model(language=language if language else b'en') + load_whisper_model(language=language) if not args.whisper_cpp: - return whisper_model.transcribe(file, language=language if language else "English") + return whisper_model.transcribe(file, language=language) res = whisper_model.transcribe(file) segments = whisper_model.extract_text_and_timestamps( res ) @@ -945,11 +973,8 @@ def prepare_dataset( files, outdir, language=None, progress=None ): transcription = [] for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): - print(f"Transcribing file: {file}") - - result = whisper_transcribe(file, language=language) # whisper_model.transcribe(file, language=language if language else "English") + result = whisper_transcribe(file, language=language) results[os.path.basename(file)] = result - print(f"Transcribed file: {file}, {len(result['segments'])} found.") waveform, sampling_rate = torchaudio.load(file) @@ -988,7 +1013,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ): return [int(iterations * d) for d in schedule] -def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ): +def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -1065,7 +1090,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni messages ) -def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, source_model=None ): +def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ): if not source_model: source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth" @@ -1090,6 +1115,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig 'float16': 'true' if half_p else 'false', 'bitsandbytes': 'true' if bnb else 'false', + + 'workers': workers if workers else 2, } if resume_path: @@ -1581,9 +1608,9 @@ def unload_tts(): global tts if tts: - print("Unloading TTS") del tts tts = None + print("Unloaded TTS") do_gc() def reload_tts( model=None ): @@ -1656,55 +1683,44 @@ def unload_voicefixer(): global voicefixer if voicefixer: - print("Unloading Voicefixer") del voicefixer voicefixer = None print("Unloaded Voicefixer") do_gc() -def load_whisper_model(name=None, progress=None, language=b'en'): +def load_whisper_model(language=None, model_name=None, progress=None): global whisper_model - if not name: - name = args.whisper_model + if not model_name: + model_name = args.whisper_model else: - args.whisper_model = name + args.whisper_model = model_name save_args_settings() - notify_progress(f"Loading Whisper model: {args.whisper_model}", progress) + if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS: + model_name = f'{model_name}.{language}' + print(f"Loading specialized model for language: {language}") + + notify_progress(f"Loading Whisper model: {model_name}", progress) if args.whisper_cpp: from whispercpp import Whisper - whisper_model = Whisper(name, models_dir='./models/', language=language.encode('ascii')) + if not language: + language = 'auto' + + whisper_model = Whisper(model_name, models_dir='./models/', language=language.encode('ascii')) else: import whisper - whisper_model = whisper.load_model(args.whisper_model) + whisper_model = whisper.load_model(model_name) + print("Loaded Whisper model") def unload_whisper(): global whisper_model if whisper_model: - print("Unloading Whisper") del whisper_model whisper_model = None print("Unloaded Whisper") - do_gc() - -""" -def update_whisper_model(name, progress=None): - if not name: - return - - args.whisper_model = name - save_args_settings() - - global whisper_model - if whisper_model: - unload_whisper() - load_whisper_model(name) - else: - args.whisper_model = name - save_args_settings() -""" \ No newline at end of file + do_gc() \ No newline at end of file diff --git a/src/webui.py b/src/webui.py index 89d9462..3ca76d0 100755 --- a/src/webui.py +++ b/src/webui.py @@ -268,6 +268,8 @@ def import_training_settings_proxy( voice ): if "ext" in config and "bitsandbytes" in config["ext"]: bnb = config["ext"]["bitsandbytes"] + workers = config['datasets']['train']['n_workers'] + messages = "\n".join(messages) return ( @@ -282,12 +284,13 @@ def import_training_settings_proxy( voice ): resume_path, half_p, bnb, + workers, source_model, messages ) -def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ): +def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -330,6 +333,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear resume_path=resume_path, half_p=half_p, bnb=bnb, + workers=workers, source_model=source_model, )) return "\n".join(messages) @@ -466,7 +470,7 @@ def setup_gradio(): with gr.Column(): dataset_settings = [ gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ), - gr.Textbox(label="Language", placeholder="English") + gr.Textbox(label="Language", value="en") ] prepare_dataset_button = gr.Button(value="Prepare") with gr.Column(): @@ -499,11 +503,16 @@ def setup_gradio(): training_settings = training_settings + [ gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), ] - training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) - training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) + + with gr.Row(): + training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) + training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) + + training_workers = gr.Number(label="Worker Processes", value=2, precision=0) + source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] ) dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" ) - training_settings = training_settings + [ training_halfp, training_bnb, source_model, dataset_list_dropdown ] + training_settings = training_settings + [ training_halfp, training_bnb, training_workers, source_model, dataset_list_dropdown ] with gr.Row(): refresh_dataset_list = gr.Button(value="Refresh Dataset List") @@ -572,7 +581,7 @@ def setup_gradio(): autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) - whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model) + whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) use_whisper_cpp = gr.Checkbox(label="Use Whisper.cpp", value=args.whisper_cpp) exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, use_whisper_cpp, training_halfp, training_bnb ] @@ -797,7 +806,7 @@ def setup_gradio(): ) import_dataset_button.click(import_training_settings_proxy, inputs=dataset_list_dropdown, - outputs=training_settings[:11] + [save_yaml_output] #console_output + outputs=training_settings[:13] + [save_yaml_output] #console_output ) save_yaml_button.click(save_training_settings_proxy, inputs=training_settings,