From d8b996911cc050c56dec47e6ffed6d40c414dabd Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 12 Apr 2023 20:02:46 +0000 Subject: [PATCH] a bunch of shit i had uncommited over the past while pertaining to VALL-E --- modules/tortoise-tts | 2 +- src/utils.py | 50 ++++++++++++++++++++++++----------- src/webui.py | 63 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/modules/tortoise-tts b/modules/tortoise-tts index 0bcdf81..f025470 160000 --- a/modules/tortoise-tts +++ b/modules/tortoise-tts @@ -1 +1 @@ -Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130 +Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0 diff --git a/src/utils.py b/src/utils.py index ae6fb89..0e1c504 100755 --- a/src/utils.py +++ b/src/utils.py @@ -75,6 +75,7 @@ try: VALLE_ENABLED = True except Exception as e: + print(e) pass if VALLE_ENABLED: @@ -156,10 +157,12 @@ def generate_valle(**kwargs): voice_cache = {} def fetch_voice( voice ): - voice_dir = f'./voices/{voice}/' + voice_dir = f'./training/{voice}/audio/' + if not os.path.isdir(voice_dir): + voice_dir = f'./voices/{voice}/' files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ] - return files - # return random.choice(files) + # return files + return random.choice(files) def get_settings( override=None ): settings = { @@ -1089,13 +1092,13 @@ class TrainingState(): 'ar-quarter.lr', 'nar-quarter.lr', ] keys['losses'] = [ - 'ar.loss', 'nar.loss', - 'ar-half.loss', 'nar-half.loss', - 'ar-quarter.loss', 'nar-quarter.loss', + 'ar.loss', 'nar.loss', 'ar+nar.loss', + 'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss', + 'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss', - 'ar.loss.nll', 'nar.loss.nll', - 'ar-half.loss.nll', 'nar-half.loss.nll', - 'ar-quarter.loss.nll', 'nar-quarter.loss.nll', + # 'ar.loss.nll', 'nar.loss.nll', + # 'ar-half.loss.nll', 'nar-half.loss.nll', + # 'ar-quarter.loss.nll', 'nar-quarter.loss.nll', ] keys['accuracies'] = [ @@ -1123,7 +1126,7 @@ class TrainingState(): prefix = "" - if data["mode"] == "validation": + if "mode" in self.info and self.info["mode"] == "validation": prefix = f'{self.info["name"] if "name" in self.info else "val"}_' self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' }) @@ -1231,6 +1234,7 @@ class TrainingState(): unq = {} averager = None + prev_state = 0 for log in logs: with open(log, 'r', encoding="utf-8") as f: @@ -1250,6 +1254,7 @@ class TrainingState(): name = "train" mode = "training" + prev_state = 0 elif line.find('Validation Metrics:') >= 0: data = json.loads(line.split("Validation Metrics:")[-1]) if "it" not in data: @@ -1257,8 +1262,15 @@ class TrainingState(): if "epoch" not in data: data['epoch'] = epoch - name = data['name'] if 'name' in data else "val" + # name = data['name'] if 'name' in data else "val" mode = "validation" + + if prev_state == 0: + name = "subtrain" + else: + name = "val" + + prev_state += 1 else: continue @@ -1272,6 +1284,7 @@ class TrainingState(): if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode: averager = { 'key': f'{it}_{name}', + 'name': name, 'mode': mode, "metrics": {} } @@ -1292,11 +1305,13 @@ class TrainingState(): if update and it <= self.last_info_check_at: continue + blacklist = [ "batch", "eval" ] for it in unq: if args.tts_backend == "vall-e": stats = unq[it] - data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()} - data['mode'] = stats + data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist } + data['name'] = stats['name'] + data['mode'] = stats['mode'] data['steps'] = len(stats['metrics']['it']) else: data = unq[it] @@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ): device = "cuda" if get_device_name() == "cuda" else "cpu" if whisper_vad: + # omits a considerable amount of the end """ if args.whisper_batchsize > 1: result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") @@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul messages = [] if not os.path.exists(infile): - raise Exception(f"Missing dataset: {infile}") + message = f"Missing dataset: {infile}" + print(message) + return message if results is None: results = json.load(open(infile, 'r', encoding="utf-8")) @@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' if not os.path.exists(infile): - raise Exception(f"Missing dataset: {infile}") + message = f"Missing dataset: {infile}" + print(message) + return message results = json.load(open(infile, 'r', encoding="utf-8")) diff --git a/src/webui.py b/src/webui.py index a0cf253..e0f7873 100755 --- a/src/webui.py +++ b/src/webui.py @@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'): def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ): return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress ) +def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ): + from pyannote.audio import Pipeline + pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token) + + messages = [] + files = sorted( get_voices(load_latents=False)[voice] ) + for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): + diarization = pipeline(file) + for turn, _, speaker in diarization.itertracks(yield_label=True): + message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}" + print(message) + messages.append(message) + + return "\n".join(messages) + +def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ): + kwargs = locals() + + messages = [] + voices = get_voice_list() + + """ + for voice in voices: + message = prepare_dataset_proxy(voice, **kwargs) + messages.append(message) + """ + for voice in voices: + print("Processing:", voice) + message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress ) + messages.append(message) + + if slice_audio: + for voice in voices: + print("Processing:", voice) + message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress ) + messages.append(message) + + for voice in voices: + print("Processing:", voice) + message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress ) + messages.append(message) + + return "\n".join(messages) + def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ): messages = [] @@ -468,6 +512,8 @@ def setup_gradio(): DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0) transcribe_button = gr.Button(value="Transcribe and Process") + transcribe_all_button = gr.Button(value="Transcribe All") + diarize_button = gr.Button(value="Diarize") with gr.Row(): slice_dataset_button = gr.Button(value="(Re)Slice Audio") @@ -579,7 +625,7 @@ def setup_gradio(): tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350, - visible=args.tts_backend=="vall-e" + visible=False, # args.tts_backend=="vall-e" ) view_losses = gr.Button(value="View Losses") @@ -611,10 +657,7 @@ def setup_gradio(): # EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0]) with gr.Column(visible=args.tts_backend=="vall-e"): - default_valle_model_choice = "" - if len(valle_models): - default_valle_model_choice = valle_models[0] - EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice) + EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0]) with gr.Column(visible=args.tts_backend=="tortoise"): EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto") @@ -859,6 +902,16 @@ def setup_gradio(): inputs=dataset_settings, outputs=prepare_dataset_output #console_output ) + transcribe_all_button.click( + prepare_all_datasets, + inputs=dataset_settings[1:], + outputs=prepare_dataset_output #console_output + ) + diarize_button.click( + diarize_dataset, + inputs=dataset_settings[0], + outputs=prepare_dataset_output #console_output + ) prepare_dataset_button.click( prepare_dataset, inputs=[