From 382a3e41048f0c371d525ed19464165a8cd5a63a Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Sat, 11 Mar 2023 21:17:11 +0000
Subject: [PATCH] rely on whisper.json for transcription, slicing, and dataset preparation

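Transcription, slicing, and train/validation preparation now all key off
./training/{voice}/whisper.json: transcribe_dataset() writes it (and can
skip already-transcribed files), while slice_dataset() and the reworked
prepare_dataset() (which absorbs prepare_validation_dataset()) read their
segments back out of it.

The file maps each source clip to its whisper result. A rough sketch of
just the fields this code reads (whisper emits more):

	{
		"clip.wav": {
			"text": "full transcription",
			"segments": [
				{ "id": 0, "start": 0.0, "end": 2.4, "text": "full" },
				{ "id": 1, "start": 2.4, "end": 4.1, "text": "transcription" }
			]
		}
	}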
---
 src/utils.py | 194 +++++++++++++++++++++------------------------------
 src/webui.py |  26 ++++---
 2 files changed, 94 insertions(+), 126 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index d549145..8fe8fc6 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -33,7 +33,7 @@ from datetime import datetime
 from datetime import timedelta
 
 from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
-from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
+from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
 
@@ -1059,6 +1059,47 @@ def validate_waveform( waveform, sample_rate ):
 		return False
 	return True
 
+def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
+	unload_tts()
+
+	global whisper_model
+	if whisper_model is None:
+		load_whisper_model(language=language)
+
+	results = {}
+
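+	# enumerate this voice's source clips straight from the voices folder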
+	files = sorted( get_voices(load_latents=False)[voice] )
+	indir = f'./training/{voice}/'
+	infile = f'{indir}/whisper.json'
+
+	os.makedirs(f'{indir}/audio/', exist_ok=True)
+
+	if os.path.exists(infile):
+		with open(infile, 'r', encoding="utf-8") as f:
+			results = json.load(f)
+
+	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+		basename = os.path.basename(file)
+
+		if basename in results and skip_existings:
+			print(f"Skipping already parsed file: {basename}")
+			continue
+
+		results[basename] = whisper_transcribe(file, language=language)
+
+		# lazy copy: mirror the source clip under {indir}/audio/ so the unsliced dataset path can point at it
+		waveform, sampling_rate = torchaudio.load(file)
+		torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
+
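+		# flush the accumulated results after every file so an interrupted
+		# run can resume with skip_existings without re-transcribing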
+		with open(infile, 'w', encoding="utf-8") as f:
+			f.write(json.dumps(results, indent='\t'))
+
+		do_gc()
+
+	unload_whisper()
+
+	return f"Processed dataset to: {indir}"
+
 def slice_dataset( voice, start_offset=0, end_offset=0 ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
@@ -1066,148 +1107,71 @@ def slice_dataset( voice, start_offset=0, end_offset=0 ):
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
 	with open(infile, 'r', encoding="utf-8") as f:
 		results = json.load(f)
 
-	transcription = []
+	files = 0
+	segments = 0
 	for filename in results:
-		idx = 0
+		files += 1
+
 		result = results[filename]
 		waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
 
 		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+			segments += 1
 			start = int((segment['start'] + start_offset) * sampling_rate)
 			end = int((segment['end'] + end_offset) * sampling_rate)
 
-			sliced_waveform = waveform[:, start:end]
-			sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
+			sliced = waveform[:, start:end]
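+			# name each slice after whisper's segment id: foo.wav -> foo_0001.wav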
+			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
 
-			if not validate_waveform( sliced_waveform, sampling_rate ):
-				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
+			if not validate_waveform( sliced, sampling_rate ):
+				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
 				continue
 
-			torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
+			torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
 
-			idx = idx + 1
-			line = f"audio/{sliced_name}|{segment['text'].strip()}"
-			transcription.append(line)
-			with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
-				f.write(f'\n{line}')
+	return f"Sliced segments: {files} => {segments}."
 
-	joined = "\n".join(transcription)
-	with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write(joined)
-
-	return f"Processed dataset to: {indir}\n{joined}"
-
-def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
-	unload_tts()
-
-	global whisper_model
-	if whisper_model is None:
-		load_whisper_model(language=language)
-
-	os.makedirs(f'{outdir}/audio/', exist_ok=True)
-
-	results = {}
-	transcription = []
-	files = sorted(files)
-
-	previous_list = []
-	if skip_existings and os.path.exists(f'{outdir}/train.txt'):
-		parsed_list = []
-		with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
-			parsed_list = f.readlines()
-
-		for line in parsed_list:
-			match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
-
-			if match is None or len(match) == 0:
-				continue
-			
-			if match[0] not in previous_list:
-				previous_list.append(f'{match[0].split("/")[-1]}.wav')
-
-
-	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-		basename = os.path.basename(file)
-
-		if basename in previous_list:
-			print(f"Skipping already parsed file: {basename}")
-			continue
-
-		result = whisper_transcribe(file, language=language)
-		results[basename] = result
-		print(f"Transcribed file: {file}, {len(result['segments'])} found.")
-
-		waveform, sampling_rate = torchaudio.load(file)
-
-		if not validate_waveform( waveform, sampling_rate ):
-			print(f"Invalid waveform: {basename}, skipping...")
-			continue
-
-		torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
-		line = f"audio/{basename}|{result['text'].strip()}"
-		transcription.append(line)
-		with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-			f.write(f'\n{line}')
-
-		do_gc()
-	
-	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
-		f.write(json.dumps(results, indent='\t'))
-
-	unload_whisper()
-
-	joined = "\n".join(transcription)
-	if not skip_existings:
-		with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-			f.write(joined)
-
-	return f"Processed dataset to: {outdir}\n{joined}"
-
-def prepare_validation_dataset( voice, text_length, audio_length ):
+def prepare_dataset( voice, use_segments, text_length, audio_length ):
 	indir = f'./training/{voice}/'
-	infile = f'{indir}/dataset.txt'
-	if not os.path.exists(infile):
-		infile = f'{indir}/train.txt'
-		with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
-			with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
-				dst.write(src.read())
+	infile = f'{indir}/whisper.json'
 
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
-	with open(infile, 'r', encoding="utf-8") as f:
-		lines = f.readlines()
+	with open(infile, 'r', encoding="utf-8") as f:
+		results = json.load(f)
 
-	validation = []
-	training = []
+	lines = {
+		'training': [],
+		'validation': [],
+	}
 
-	for line in lines:
-		split = line.split("|")
-		filename = split[0]
-		text = split[1]
-		culled = len(text) < text_length
+	for filename in results:
+		result = results[filename]
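+		# when not slicing, treat the whole file's transcription as a single segment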
+		segments = result['segments'] if use_segments else [{'text': result['text']}]
+		for segment in segments:
+			text = segment['text'].strip()
+			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
 
-		if not culled and audio_length > 0:
-			metadata = torchaudio.info(f'{indir}/{filename}')
-			duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
-			culled = duration < audio_length
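+			# cull lines that are too short (by text or audio length) into the validation set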
+			culled = len(text) < text_length
+			if not culled and audio_length > 0:
+				metadata = torchaudio.info(f'{indir}/audio/{file}')
+				# num_frames is frames per channel, so seconds = frames / sample rate
+				duration = metadata.num_frames / metadata.sample_rate
+				culled = duration < audio_length
 
-		if culled:
-			validation.append(line.strip())
-		else:
-			training.append(line.strip())
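+			# dataset lines keep the same "audio/<file>|<text>" format as before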
+			lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
+
+	training_joined = "\n".join(lines['training'])
+	validation_joined = "\n".join(lines['validation'])
 
 	with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(training))
+		f.write(training_joined)
 
 	with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(validation))
+		f.write(validation_joined)
 
-	msg = f"Culled {len(validation)}/{len(lines)} lines."
-	print(msg)
+	msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
 	return msg
 
 def calc_iterations( epochs, lines, batch_size ):
diff --git a/src/webui.py b/src/webui.py
index c4e0190..82bec2f 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -182,16 +182,19 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 		gr.update(visible=j is not None),
 	)
 
-def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
+def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=False) ):
 	messages = []
-	message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
+
+	message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
 	messages.append(message)
+
 	if slice_audio:
 		message = slice_dataset( voice )
 		messages.append(message)
-	if validation_text_length > 0 or validation_audio_length > 0:
-		message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
-		messages.append(message)
+
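+	# only key the dataset off per-segment slices when slicing actually ran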
+	message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length )
+	messages.append(message)
+
 	return "\n".join(messages)
 
 def update_args_proxy( *args ):
@@ -421,8 +424,8 @@ def setup_gradio():
 
 						with gr.Row():
 							transcribe_button = gr.Button(value="Transcribe")
-							prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
 							slice_dataset_button = gr.Button(value="(Re)Slice Audio")
+							prepare_dataset_button = gr.Button(value="(Re)Create Dataset")
 
 						with gr.Row():
 							EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
@@ -654,7 +657,7 @@ def setup_gradio():
 			inputs=None,
 			outputs=[
 				GENERATE_SETTINGS['voice'],
-				dataset_settings[0],
+				DATASET_SETTINGS['voice'],
 				history_voices
 			]
 		)
@@ -742,10 +745,11 @@ def setup_gradio():
 			inputs=dataset_settings,
 			outputs=prepare_dataset_output #console_output
 		)
-		prepare_validation_button.click(
-			prepare_validation_dataset,
+		prepare_dataset_button.click(
+			prepare_dataset,
 			inputs=[
-				dataset_settings[0],
+				DATASET_SETTINGS['voice'],
+				DATASET_SETTINGS['slice'],
 				DATASET_SETTINGS['validation_text_length'],
 				DATASET_SETTINGS['validation_audio_length'],
 			],
@@ -754,7 +758,7 @@ def setup_gradio():
 		slice_dataset_button.click(
 			slice_dataset,
 			inputs=[
-				dataset_settings[0]
+				DATASET_SETTINGS['voice']
 			],
 			outputs=prepare_dataset_output
 		)