From da9b4b5fb541479bc23416472e4009e92176aecc Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Sat, 18 Mar 2023 15:14:22 +0000
Subject: [PATCH] Launch vall-e training via deepspeed; propagate language codes through dataset prep and tokenization

---
 src/utils.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index d93f791..52a294a 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -702,7 +702,7 @@ class TrainingState():
 
 	def spawn_process(self, config_path, gpus=1):
 		if args.tts_backend == "vall-e":
-			self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"']
+			self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
 		else:
 			self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
 
@@ -1358,7 +1358,7 @@ def create_dataset_json( path ):
 
 def phonemizer( text, language="en-us" ):
 	from phonemizer import phonemize
-	if language == "english":
+	if language == "en":
 		language = "en-us"
 	return phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
 
@@ -1393,7 +1393,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		use_segment = use_segments
 
 		result = results[filename]
-		language = LANGUAGES[result['language']] if result['language'] in LANGUAGES else None
+		lang = result['language']
+		language = LANGUAGES[lang] if lang in LANGUAGES else lang
 		normalizer = EnglishTextNormalizer() if language and language == "english" else BasicTextNormalizer()
 
 		# check if unsegmented text exceeds 200 characters
@@ -1445,6 +1446,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		if not use_segment:
 			segments[filename] = {
 				'text': result['text'],
+				'lang': lang,
 				'language': language,
 				'normalizer': normalizer,
 				'phonemes': result['phonemes'] if 'phonemes' in result else None
@@ -1457,6 +1459,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 
 				segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
 					'text': segment['text'],
+					'lang': lang,
 					'language': language,
 					'normalizer': normalizer,
 					'phonemes': segment['phonemes'] if 'phonemes' in segment else None
@@ -1467,11 +1470,12 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		path = f'{indir}/audio/{file}'
 
 		text = result['text']
+		lang = result['lang']
 		language = result['language']
 		normalizer = result['normalizer']
 		phonemes = result['phonemes']
 		if phonemize and phonemes is None:
-			phonemes = phonemizer( text, language=language )
+			phonemes = phonemizer( text, language=lang )
 		if phonemize:
 			text = phonemes
 
@@ -1514,7 +1518,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
 		print("Quantized:", file)
 
-		tokens = tokenize_text(text, stringed=False, skip_specials=True)
+		tokens = tokenize_text(text, config="./models/tokenizers/ipa.json", stringed=False, skip_specials=True)
 		tokenized = " ".join( tokens )
 		tokenized = tokenized.replace(" \u02C8", "\u02C8")
 		tokenized = tokenized.replace(" \u02CC", "\u02CC")
@@ -1888,11 +1892,14 @@ def get_tokenizer_jsons( dir="./models/tokenizers/" ):
 	additionals = sorted([ f'{dir}/{d}' for d in os.listdir(dir) if d[-5:] == ".json" ]) if os.path.isdir(dir) else []
 	return relative_paths([ "./modules/tortoise-tts/tortoise/data/tokenizer.json" ] + additionals)
 
-def tokenize_text( text, stringed=True, skip_specials=False ):
+def tokenize_text( text, config=None, stringed=True, skip_specials=False ):
 	from tortoise.utils.tokenizer import VoiceBpeTokenizer
 
+	if not config:
+		config = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0]
+
 	if not tts:
-		tokenizer = VoiceBpeTokenizer(args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0])
+		tokenizer = VoiceBpeTokenizer(config)
 	else:
 		tokenizer = tts.tokenizer