From 72a38ff2fc6aeb8b8184bb8e8103497a37ce151d Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Mon, 21 Aug 2023 03:31:49 +0000
Subject: [PATCH] made initialization faster if there's a lot of voice files
 (because glob fucking sucks), commiting changes buried on my training rig

---
 modules/tortoise-tts |   2 +-
 requirements.txt     |   3 +-
 src/utils.py         | 314 +++++++++++++++++++++++++++----------------
 3 files changed, 201 insertions(+), 118 deletions(-)

diff --git a/modules/tortoise-tts b/modules/tortoise-tts
index 5ff00bf..cbd3c95 160000
--- a/modules/tortoise-tts
+++ b/modules/tortoise-tts
@@ -1 +1 @@
-Subproject commit 5ff00bf3bfa97e2c8e9f166b920273f83ac9d8f0
+Subproject commit cbd3c95c42ac1da9772f61b9895954ee693075c9
diff --git a/requirements.txt b/requirements.txt
index e1794a4..fcb0746 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ music-tag
 voicefixer
 psutil
 phonemizer
-pydantic==1.10.11
\ No newline at end of file
+pydantic==1.10.11
+websockets
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index 759e054..a79c6ce 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -45,7 +45,7 @@ from tortoise.utils.device import get_device_name, set_device_name, get_device_c
 
 MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
 
-WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
+WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
 WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
 WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
 VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
@@ -61,12 +61,15 @@ RESAMPLERS = {}
 
 MIN_TRAINING_DURATION = 0.6
 MAX_TRAINING_DURATION = 11.6097505669
+MAX_TRAINING_CHAR_LENGTH = 200
 
 VALLE_ENABLED = False
 BARK_ENABLED = False
 
 VERBOSE_DEBUG = True
 
+import traceback
+
 try:
 	from whisper.normalizers.english import EnglishTextNormalizer
 	from whisper.normalizers.basic import BasicTextNormalizer
@@ -75,7 +78,7 @@ try:
 	print("Whisper detected")
 except Exception as e:
 	if VERBOSE_DEBUG:
-		print("Error:", e)
+		print(traceback.format_exc())
 	pass
 
 try:
@@ -90,12 +93,14 @@ try:
 	VALLE_ENABLED = True
 except Exception as e:
 	if VERBOSE_DEBUG:
-		print("Error:", e)
+		print(traceback.format_exc())
 	pass
 
 if VALLE_ENABLED:
 	TTSES.append('vall-e')
 
+# torchaudio.set_audio_backend('soundfile')
+
 try:
 	import bark
 	from bark import text_to_semantic
@@ -109,35 +114,10 @@ try:
 	BARK_ENABLED = True
 except Exception as e:
 	if VERBOSE_DEBUG:
-		print("Error:", e)
+		print(traceback.format_exc())
 	pass
 
 if BARK_ENABLED:
-	try:
-		from vocos import Vocos
-		VOCOS_ENABLED = True
-		print("Vocos detected")
-	except Exception as e:
-		if VERBOSE_DEBUG:
-			print("Error:", e)
-		VOCOS_ENABLED = False
-
-	try:
-		from hubert.hubert_manager import HuBERTManager
-		from hubert.pre_kmeans_hubert import CustomHubert
-		from hubert.customtokenizer import CustomTokenizer
-
-		hubert_manager = HuBERTManager()
-		hubert_manager.make_sure_hubert_installed()
-		hubert_manager.make_sure_tokenizer_installed()
-
-		HUBERT_ENABLED = True
-		print("HuBERT detected")
-	except Exception as e:
-		if VERBOSE_DEBUG:
-			print("Error:", e)
-		HUBERT_ENABLED = False
-
 	TTSES.append('bark')
 
 	def semantic_to_audio_tokens(
@@ -181,7 +161,32 @@ if BARK_ENABLED:
 
 			self.device = get_device_name()
 
-			if VOCOS_ENABLED:
+			try:
+				from vocos import Vocos
+				self.vocos_enabled = True
+				print("Vocos detected")
+			except Exception as e:
+				if VERBOSE_DEBUG:
+					print(traceback.format_exc())
+				self.vocos_enabled = False
+
+			try:
+				from hubert.hubert_manager import HuBERTManager
+				from hubert.pre_kmeans_hubert import CustomHubert
+				from hubert.customtokenizer import CustomTokenizer
+
+				hubert_manager = HuBERTManager()
+				hubert_manager.make_sure_hubert_installed()
+				hubert_manager.make_sure_tokenizer_installed()
+
+				self.hubert_enabled = True
+				print("HuBERT detected")
+			except Exception as e:
+				if VERBOSE_DEBUG:
+					print(traceback.format_exc())
+				self.hubert_enabled = False
+
+			if self.vocos_enabled:
 				self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
 
 		def create_voice( self, voice ):
@@ -238,7 +243,7 @@ if BARK_ENABLED:
 
 			# generate semantic tokens
 
-			if HUBERT_ENABLED:
+			if self.hubert_enabled:
 				wav = wav.to(self.device)
 
 				# Extract discrete codes from EnCodec
@@ -426,7 +431,7 @@ def generate_bark(**kwargs):
 	idx_cache = {}
 	for i, file in enumerate(os.listdir(outdir)):
 		filename = os.path.basename(file)
-		extension = os.path.splitext(filename)[1]
+		extension = os.path.splitext(filename)[-1][1:]
 		if extension != ".json" and extension != ".wav":
 			continue
 		match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
@@ -672,18 +677,23 @@ def generate_valle(**kwargs):
 
 	voice_cache = {}
 	def fetch_voice( voice ):
+		if voice in voice_cache:
+			return voice_cache[voice]
 		voice_dir = f'./training/{voice}/audio/'
-		if not os.path.isdir(voice_dir):
+
+		if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0:
 			voice_dir = f'./voices/{voice}/'
+
 		files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
 		# return files
-		return random.choice(files)
+		voice_cache[voice] = random.choice(files)
+		return voice_cache[voice]
 
 	def get_settings( override=None ):
 		settings = {
 			'ar_temp': float(parameters['temperature']),
 			'nar_temp': float(parameters['temperature']),
-			'max_ar_samples': parameters['num_autoregressive_samples'],
+			'max_ar_steps': parameters['num_autoregressive_samples'],
 		}
 
 		# could be better to just do a ternary on everything above, but i am not a professional
@@ -697,7 +707,7 @@ def generate_valle(**kwargs):
 					continue
 				settings[k] = override[k]
 
-		settings['reference'] = fetch_voice(voice=selected_voice)
+		settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ]
 		return settings
 
 	if not parameters['delimiter']:
@@ -723,7 +733,7 @@ def generate_valle(**kwargs):
 	idx_cache = {}
 	for i, file in enumerate(os.listdir(outdir)):
 		filename = os.path.basename(file)
-		extension = os.path.splitext(filename)[1]
+		extension = os.path.splitext(filename)[-1][1:]
 		if extension != ".json" and extension != ".wav":
 			continue
 		match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
@@ -783,11 +793,14 @@ def generate_valle(**kwargs):
 			except Exception as e:
 				raise Exception("Prompt settings editing requested, but received invalid JSON")
 
-		settings = get_settings( override=override )
-		reference = settings['reference']
-		settings.pop("reference")
+		name = get_name(line=line, candidate=0)
 
-		gen = tts.inference(cut_text, reference, **settings )
+		settings = get_settings( override=override )
+		references = settings['references']
+		settings.pop("references")
+		settings['out_path'] = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+
+		gen = tts.inference(cut_text, references, **settings )
 
 		run_time = time.time()-start_time
 		print(f"Generating line took {run_time} seconds")
@@ -805,7 +818,7 @@ def generate_valle(**kwargs):
 
 			# save here in case some error happens mid-batch
 			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
-			soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
+			#soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
 			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 			audio_cache[name] = {
@@ -1085,7 +1098,7 @@ def generate_tortoise(**kwargs):
 	idx_cache = {}
 	for i, file in enumerate(os.listdir(outdir)):
 		filename = os.path.basename(file)
-		extension = os.path.splitext(filename)[1]
+		extension = os.path.splitext(filename)[-1][1:]
 		if extension != ".json" and extension != ".wav":
 			continue
 		match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
@@ -1605,30 +1618,18 @@ class TrainingState():
 			if args.tts_backend == "vall-e":
 				keys['lrs'] = [
 					'ar.lr', 'nar.lr',
-					'ar-half.lr', 'nar-half.lr',
-					'ar-quarter.lr', 'nar-quarter.lr',
 				]
 				keys['losses'] = [
-					'ar.loss', 'nar.loss', 'ar+nar.loss',
-					'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
-					'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
-
+				#	'ar.loss', 'nar.loss', 'ar+nar.loss',
 					'ar.loss.nll', 'nar.loss.nll',
-					'ar-half.loss.nll', 'nar-half.loss.nll',
-					'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
 				]
 
 				keys['accuracies'] = [
 					'ar.loss.acc', 'nar.loss.acc',
-					'ar-half.loss.acc', 'nar-half.loss.acc',
-					'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
+					'ar.stats.acc', 'nar.loss.acc',
 				]
-				keys['precisions'] = [
-					'ar.loss.precision', 'nar.loss.precision',
-					'ar-half.loss.precision', 'nar-half.loss.precision',
-					'ar-quarter.loss.precision', 'nar-quarter.loss.precision',
-				]
-				keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
+				keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ]
+				keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm']
 
 			for k in keys['lrs']:
 				if k not in self.info:
@@ -1746,7 +1747,8 @@ class TrainingState():
 		if args.tts_backend == "tortoise":
 			logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
 		else:
-			logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
+			log_dir = "logs"
+			logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/{log_dir}/') ])
 
 		if update:
 			logs = [logs[-1]]
@@ -2219,6 +2221,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 	files = get_voice(voice, load_latents=False)
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
+
+	quantize_in_memory = args.tts_backend == "vall-e"
 	
 	os.makedirs(f'{indir}/audio/', exist_ok=True)
 	
@@ -2245,13 +2249,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 			continue
 
 		results[basename] = result
-		waveform, sample_rate = torchaudio.load(file)
-		# resample to the input rate, since it'll get resampled for training anyways
-		# this should also "help" increase throughput a bit when filling the dataloaders
-		waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
-		if waveform.shape[0] == 2:
-			waveform = waveform[:1]
-		torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
+
+		if not quantize_in_memory:
+			waveform, sample_rate = torchaudio.load(file)
+			# resample to the input rate, since it'll get resampled for training anyways
+			# this should also "help" increase throughput a bit when filling the dataloaders
+			waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
+			if waveform.shape[0] == 2:
+				waveform = waveform[:1]
+			
+			try:
+				kwargs = {}
+				if basename[-4:] == ".wav":
+					kwargs['encoding'] = "PCM_S"
+					kwargs['bits_per_sample'] = 16
+
+				torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs)
+			except Exception as e:
+				print(e)
 
 		with open(infile, 'w', encoding="utf-8") as f:
 			f.write(json.dumps(results, indent='\t'))
@@ -2317,6 +2332,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 	segments = 0
 	for filename in results:
 		path = f'./voices/{voice}/{filename}'
+		extension = os.path.splitext(filename)[-1][1:]
+		out_extension = extension # "wav"
+
 		if not os.path.exists(path):
 			path = f'./training/{voice}/{filename}'
 
@@ -2333,7 +2351,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 		duration = num_frames / sample_rate
 
 		for segment in result['segments']: 
-			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
+			file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
 			
 			sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
 			if error:
@@ -2341,12 +2359,17 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 				print(message)
 				messages.append(message)
 				continue
-			sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
+		#	sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
 
 			if waveform.shape[0] == 2:
 				waveform = waveform[:1]
 				
-			torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)
+			kwargs = {}
+			if file[-4:] == ".wav":
+				kwargs['encoding'] = "PCM_S"
+				kwargs['bits_per_sample'] = 16
+
+			torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs)
 			
 			segments +=1
 
@@ -2462,18 +2485,32 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 
 	errored = 0
 	messages = []
-	normalize = True
+	normalize = False # True
 	phonemize = should_phonemize()
 	lines = { 'training': [], 'validation': [] }
 	segments = {}
 
+	quantize_in_memory = args.tts_backend == "vall-e"
+
 	if args.tts_backend != "tortoise":
 		text_length = 0
 		audio_length = 0
 
+	start_offset = -0.1
+	end_offset = 0.1
+	trim_silence = False
+
+	TARGET_SAMPLE_RATE = 22050
+	if args.tts_backend != "tortoise":
+		TARGET_SAMPLE_RATE = 24000
+	if tts:
+		TARGET_SAMPLE_RATE = tts.input_sample_rate
+
 	for filename in tqdm(results, desc="Parsing results"):
 		use_segment = use_segments
 
+		extension = os.path.splitext(filename)[-1][1:]
+		out_extension = extension # "wav"
 		result = results[filename]
 		lang = result['language']
 		language = LANGUAGES[lang] if lang in LANGUAGES else lang
@@ -2481,8 +2518,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 
 		# check if unsegmented text exceeds 200 characters
 		if not use_segment:
-			if len(result['text']) > 200:
-				message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
+			if len(result['text']) > MAX_TRAINING_CHAR_LENGTH:
+				message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(result['text'])}), using segments: {filename}"
 				print(message)
 				messages.append(message)
 				use_segment = True
@@ -2490,13 +2527,15 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		# check if unsegmented audio exceeds 11.6s
 		if not use_segment:
 			path = f'{indir}/audio/{filename}'
-			if not os.path.exists(path):
+			if not quantize_in_memory and not os.path.exists(path):
 				messages.append(f"Missing source audio: {filename}")
 				errored += 1
 				continue
 
-			metadata = torchaudio.info(path)
-			duration = metadata.num_frames / metadata.sample_rate
+			duration = 0
+			for segment in result['segments']:
+				duration = max(duration, result['segments'][segment]['end'])
+
 			if duration >= MAX_TRAINING_DURATION:
 				message = f"Audio too large, using segments: {filename}"
 				print(message)
@@ -2511,19 +2550,36 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 				if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
 					continue
 
-				path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
+				path = f'{indir}/audio/' + filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
 				if os.path.exists(path):
 					continue
 				exists = False
 				break
 
-			if not exists:
+			if not quantize_in_memory and not exists:
 				tmp = {}
 				tmp[filename] = result
 				print(f"Audio not segmented, segmenting: {filename}")
 				message = slice_dataset( voice, results=tmp )
 				print(message)
 				messages = messages + message.split("\n")
+		
+		waveform = None
+		
+
+		if quantize_in_memory:
+			path = f'{indir}/audio/{filename}'
+			if not os.path.exists(path):
+				path = f'./voices/{voice}/{filename}'
+
+			if not os.path.exists(path):
+				message = f"Audio not found: {path}"
+				print(message)
+				messages.append(message)
+				#continue
+			else:
+				waveform = torchaudio.load(path)
+				waveform = resample(waveform[0], waveform[1], TARGET_SAMPLE_RATE)
 
 		if not use_segment:
 			segments[filename] = {
@@ -2533,13 +2589,18 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 				'normalizer': normalizer,
 				'phonemes': result['phonemes'] if 'phonemes' in result else None
 			}
+
+			if waveform:
+				segments[filename]['waveform'] = waveform
 		else:
 			for segment in result['segments']:
 				duration = segment['end'] - segment['start']
 				if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
 					continue
 
-				segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
+				file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
+
+				segments[file] = {
 					'text': segment['text'],
 					'lang': lang,
 					'language': language,
@@ -2547,22 +2608,27 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 					'phonemes': segment['phonemes'] if 'phonemes' in segment else None
 				}
 
+				if waveform:
+					sliced, error = slice_waveform( waveform[0], waveform[1], segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
+					if error:
+						message = f"{error}, skipping... {file}"
+						print(message)
+						messages.append(message)
+						segments[file]['error'] = error
+						#continue
+					else:
+						segments[file]['waveform'] = (sliced, waveform[1])
+
 	jobs = {
 		'quantize':  [[], []],
 		'phonemize': [[], []],
 	}
 
 	for file in tqdm(segments, desc="Parsing segments"):
+		extension = os.path.splitext(file)[-1][1:]
 		result = segments[file]
 		path = f'{indir}/audio/{file}'
 
-		if not os.path.exists(path):
-			message = f"Missing segment, skipping... {file}"
-			print(message)
-			messages.append(message)
-			errored += 1
-			continue
-
 		text = result['text']
 		lang = result['lang']
 		language = result['language']
@@ -2573,28 +2639,20 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		
 		normalized = normalizer(text) if normalize else text
 
-		if len(text) > 200:
-			message = f"Text length too long (200 < {len(text)}), skipping... {file}"
+		if len(text) > MAX_TRAINING_CHAR_LENGTH:
+			message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(text)}), skipping... {file}"
 			print(message)
 			messages.append(message)
 			errored += 1
 			continue
 
-		waveform, sample_rate = torchaudio.load(path)
-		num_channels, num_frames = waveform.shape
-		duration = num_frames / sample_rate
+		# num_channels, num_frames = waveform.shape
+		#duration = num_frames / sample_rate
 
-		error = validate_waveform( waveform, sample_rate )
-		if error:
-			message = f"{error}, skipping... {file}"
-			print(message)
-			messages.append(message)
-			errored += 1
-			continue
 
 		culled = len(text) < text_length
-		if not culled and audio_length > 0:
-			culled = duration < audio_length
+		#if not culled and audio_length > 0:
+		#	culled = duration < audio_length
 
 		line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
 
@@ -2605,17 +2663,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		
 		os.makedirs(f'{indir}/valle/', exist_ok=True)
 
-		qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
-		if not os.path.exists(qnt_file):
-			jobs['quantize'][0].append(qnt_file)
-			jobs['quantize'][1].append((waveform, sample_rate))
-			"""
-			quantized = valle_quantize( waveform, sample_rate ).cpu()
-			torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
-			print("Quantized:", file)
-			"""
-
-		phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
+		#phn_file = f'{indir}/valle/{file.replace(f".{extension}",".phn.txt")}'
+		phn_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".phn.txt")}'
 		if not os.path.exists(phn_file):
 			jobs['phonemize'][0].append(phn_file)
 			jobs['phonemize'][1].append(normalized)
@@ -2625,13 +2674,46 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 			print("Phonemized:", file, normalized, text)
 			"""
 
+		#qnt_file = f'{indir}/valle/{file.replace(f".{extension}",".qnt.pt")}'
+		qnt_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".qnt.pt")}'
+		if 'error' not in result:
+			if not quantize_in_memory and not os.path.exists(path):
+				message = f"Missing segment, skipping... {file}"
+				print(message)
+				messages.append(message)
+				errored += 1
+				continue
+
+		if not os.path.exists(qnt_file):
+			waveform = None
+			if 'waveform' in result:
+				waveform, sample_rate = result['waveform']
+			elif os.path.exists(path):
+				waveform, sample_rate = torchaudio.load(path)
+				error = validate_waveform( waveform, sample_rate )
+				if error:
+					message = f"{error}, skipping... {file}"
+					print(message)
+					messages.append(message)
+					errored += 1
+					continue
+
+			if waveform is not None:
+				jobs['quantize'][0].append(qnt_file)
+				jobs['quantize'][1].append((waveform, sample_rate))
+				"""
+				quantized = valle_quantize( waveform, sample_rate ).cpu()
+				torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
+				print("Quantized:", file)
+				"""
+
 	for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"):
 		qnt_file = jobs['quantize'][0][i]
 		waveform, sample_rate = jobs['quantize'][1][i]
 
 		quantized = valle_quantize( waveform, sample_rate ).cpu()
 		torch.save(quantized, qnt_file)
-		print("Quantized:", qnt_file)
+		#print("Quantized:", qnt_file)
 
 	for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"):
 		phn_file = jobs['phonemize'][0][i]
@@ -2640,7 +2722,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		try:
 			phonemized = valle_phonemize( normalized )
 			open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
-			print("Phonemized:", phn_file)
+			#print("Phonemized:", phn_file)
 		except Exception as e:
 			message = f"Failed to phonemize: {phn_file}: {normalized}"
 			messages.append(message)
@@ -2980,7 +3062,7 @@ def get_voice( name, dir=get_voice_dir(), load_latents=True ):
 		voice = voice + list(glob(f'{subj}/*.pth'))
 	return sorted( voice )
 
-def get_voice_list(dir=get_voice_dir(), append_defaults=False):
+def get_voice_list(dir=get_voice_dir(), append_defaults=False, extensions=["wav", "mp3", "flac", "pth"]):
 	defaults = [ "random", "microphone" ]
 	os.makedirs(dir, exist_ok=True)
 	#res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
@@ -2993,7 +3075,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
 			continue
 		if len(os.listdir(os.path.join(dir, name))) == 0:
 			continue
-		files = get_voice( name, dir=dir )
+		files = get_voice( name, dir=dir, extensions=extensions )
 
 		if len(files) > 0:
 			res.append(name)
@@ -3001,7 +3083,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
 			for subdir in os.listdir(f'{dir}/{name}'):
 				if not os.path.isdir(f'{dir}/{name}/{subdir}'):
 					continue
-				files = get_voice( f'{name}/{subdir}', dir=dir )
+				files = get_voice( f'{name}/{subdir}', dir=dir, extensions=extensions )
 				if len(files) == 0:
 					continue
 				res.append(f'{name}/{subdir}')