From ee1b048d07551b68b946b764f70f2a3daefece22 Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Mon, 13 Mar 2023 04:26:00 +0000
Subject: [PATCH] when creating the train/validation datasets, use segments if
 the main audio's duration is too long, and slice to make the segments if they
 don't exist
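
Also moves the per-segment slicing into a reusable slice_waveform() helper
and replaces the hardcoded duration bounds with MIN/MAX_TRAINING_DURATION
constants.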

---
 src/utils.py | 123 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 91 insertions(+), 32 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 45f43d1..45f8e9f 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -51,6 +51,9 @@ LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
 
 RESAMPLERS = {}
 
+MIN_TRAINING_DURATION = 0.6
+MAX_TRAINING_DURATION = 11.6097505669
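+# the maximum works out to exactly 255995 samples at 22050Hz, presumably to match the trainer's max_wav_length cap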
+
 args = None
 tts = None
 tts_loading = False
@@ -62,6 +65,9 @@ training_state = None
 current_voice = None
 
 def resample( waveform, input_rate, output_rate=44100 ):
+	# mono-ize: average all channels down to a single channel
+	waveform = torch.mean(waveform, dim=0, keepdim=True)
+
 	if input_rate == output_rate:
 		return waveform, output_rate
 
@@ -1066,18 +1072,19 @@ def whisper_transcribe( file, language=None ):
 			result['segments'].append(reparsed)
 		return result
 
-def validate_waveform( waveform, sample_rate ):
+def validate_waveform( waveform, sample_rate, min_only=False ):
 	if not torch.any(waveform < 0):
 		return "Waveform is empty"
 
 	num_channels, num_frames = waveform.shape
 	duration = num_channels * num_frames / sample_rate
 	
-	if duration < 0.6:
-		return "Duration too short ({:.3f} < 0.6s)".format(duration)
+	if duration < MIN_TRAINING_DURATION:
+		return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, MIN_TRAINING_DURATION)
 
-	if duration > 11:
-		return "Duration too long (11s < {:.3f})".format(duration)
+	if not min_only and duration > MAX_TRAINING_DURATION:
+		return "Duration too long ({:.3f}s < {:.3f}s)".format(MAX_TRAINING_DURATION, duration)
 
 	return
 
@@ -1122,7 +1129,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 
 	return f"Processed dataset to: {indir}"
 
-def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
+def slice_waveform( waveform, sample_rate, start, end, trim ):
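+	# convert the start/end times (in seconds) into frame offsets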
+	start = int(start * sample_rate)
+	end = int(end * sample_rate)
+
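+	# clamp the slice bounds to the extent of the source waveform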
+	if start < 0:
+		start = 0
+	if end >= waveform.shape[-1]:
+		end = waveform.shape[-1] - 1
+
+	sliced = waveform[:, start:end]
+
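+	# only the minimum duration is checked here; over-long slices are caught later by the full validation in prepare_dataset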
+	error = validate_waveform( sliced, sample_rate, min_only=True )
+	if trim and not error:
+		sliced = torchaudio.functional.vad( sliced, sample_rate )
+
+	return sliced, error
+
+def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
 	messages = []
@@ -1130,7 +1154,8 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
-	results = json.load(open(infile, 'r', encoding="utf-8"))
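+	# a results dict can be passed in (prepare_dataset passes a single file's transcription) to slice only a subset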
+	if results is None:
+		results = json.load(open(infile, 'r', encoding="utf-8"))
 
 	files = 0
 	segments = 0
@@ -1140,37 +1165,35 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 			path = f'./training/{voice}/{filename}'
 
 		if not os.path.exists(path):
-			messages.append(f"Missing source audio: {filename}")
+			message = f"Missing source audio: {filename}"
+			print(message)
+			messages.append(message)
 			continue
 
 		files += 1
 		result = results[filename]
 		waveform, sample_rate = torchaudio.load(path)
 
 		for segment in result['segments']: 
-			start = int((segment['start'] + start_offset) * sample_rate)
-			end = int((segment['end'] + end_offset) * sample_rate)
-
-			if start < 0:
-				start = 0
-			if end >= waveform.shape[-1]:
-				end = waveform.shape[-1] - 1
-
-			sliced = waveform[:, start:end]
 			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
-
-			if trim_silence:
-				sliced = torchaudio.functional.vad( sliced, sample_rate )
 			
-			sliced, sample_rate = resample( sliced, sample_rate, 22050 )
-			torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
+			sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
+			if error:
+				message = f"{error}, skipping... {file}"
+				print(message)
+				messages.append(message)
+				continue
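+			# down-sample to 22050Hz, the rate the training data is saved at, before writing the slice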
+			sliced, _ = resample( sliced, sample_rate, 22050 )
+			torchaudio.save(f"{indir}/audio/{file}", sliced, 22050)
 			
 			segments +=1
 
 	messages.append(f"Sliced segments: {files} => {segments}.")
 	return "\n".join(messages)
 
-def prepare_dataset( voice, use_segments, text_length, audio_length ):
+def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
 	messages = []
@@ -1187,31 +1210,67 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ):
 
 	for filename in results:
 		result = results[filename]
-		segments = result['segments'] if use_segments else [{'text': result['text']}]
-		for segment in segments:
-			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
-			path = f'{indir}/audio/{file}'
+		use_segment = use_segments
+
+		# fall back to segments if the unsegmented audio exceeds MAX_TRAINING_DURATION
+		if not use_segment:
+			path = f'{indir}/audio/{filename}'
 			if not os.path.exists(path):
-				messages.append(f"Missing source audio: {file}")
+				messages.append(f"Missing source audio: {filename}")
+				continue
+
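+			# probe the file's metadata for its duration without decoding the audio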
+			metadata = torchaudio.info(path)
+			duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+			if duration >= MAX_TRAINING_DURATION:
+				message = f"Audio too long, using segments: {filename}"
+				print(message)
+				messages.append(message)
+				use_segment = True
+
+		segments = result['segments'] if use_segment else [{'text': result['text']}]
+
+		for segment in segments:
+			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segment else filename
+			path = f'{indir}/audio/{file}'
+			# slice the segment on the fly if it hasn't been generated yet
+			if not os.path.exists(path):
+				tmp_results = { filename: result }
+				print(f"Audio not segmented, segmenting: {filename}")
+				message = slice_dataset( voice, results=tmp_results )
+				print(message)
+				messages.extend(message.split("\n"))
+
+			if not os.path.exists(path):
+				message = f"Missing source audio: {file}"
+				print(message)
+				messages.append(message)
 				continue
 			
 			text = segment['text'].strip()
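+			# staging for text normalization: for now the "normalized" text is just the raw text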
+			normalized_text = text
+
 			if len(text) > 200:
-				messages.append(f"[{file}] Text length too long (200 < {len(text)}), skipping...")
+				message = f"Text length too long (200 < {len(text)}), skipping... {file}"
+				print(message)
+				messages.append(message)
+				continue
 
 			waveform, sample_rate = torchaudio.load(path)
-			num_channels, num_frames = waveform.shape
-			duration = num_channels * num_frames / sample_rate
 
 			error = validate_waveform( waveform, sample_rate )
 			if error:
-				messages.append(f"[{file}]: {error}, skipping...")
+				message = f"{error}, skipping... {file}"
+				print(message)
+				messages.append(message)
 				continue
 
 			culled = len(text) < text_length
 			if not culled and audio_length > 0:
+				num_channels, num_frames = waveform.shape
+				duration = num_channels * num_frames / sample_rate
 				culled = duration < audio_length
 
+			# once normalization is implemented, the line format can be extended to audio|text|normalized_text:
+			# lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}')
 			lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
 
 	training_joined = "\n".join(lines['training'])