This commit is contained in:
mrq 2023-03-18 15:14:22 +00:00
parent f44895978d
commit da9b4b5fb5

View File

@ -702,7 +702,7 @@ class TrainingState():
def spawn_process(self, config_path, gpus=1):
if args.tts_backend == "vall-e":
self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"']
self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
else:
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
@ -1358,7 +1358,7 @@ def create_dataset_json( path ):
def phonemizer( text, language="en-us" ):
from phonemizer import phonemize
if language == "english":
if language == "en":
language = "en-us"
return phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
@ -1393,7 +1393,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
use_segment = use_segments
result = results[filename]
language = LANGUAGES[result['language']] if result['language'] in LANGUAGES else None
lang = result['language']
language = LANGUAGES[lang] if lang in LANGUAGES else lang
normalizer = EnglishTextNormalizer() if language and language == "english" else BasicTextNormalizer()
# check if unsegmented text exceeds 200 characters
@ -1445,6 +1446,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if not use_segment:
segments[filename] = {
'text': result['text'],
'lang': lang,
'language': language,
'normalizer': normalizer,
'phonemes': result['phonemes'] if 'phonemes' in result else None
@ -1457,6 +1459,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
'text': segment['text'],
'lang': lang,
'language': language,
'normalizer': normalizer,
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
@ -1467,11 +1470,12 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
path = f'{indir}/audio/{file}'
text = result['text']
lang = result['lang']
language = result['language']
normalizer = result['normalizer']
phonemes = result['phonemes']
if phonemize and phonemes is None:
phonemes = phonemizer( text, language=language )
phonemes = phonemizer( text, language=lang )
if phonemize:
text = phonemes
@ -1514,7 +1518,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
tokens = tokenize_text(text, stringed=False, skip_specials=True)
tokens = tokenize_text(text, config="./models/tokenizers/ipa.json", stringed=False, skip_specials=True)
tokenized = " ".join( tokens )
tokenized = tokenized.replace(" \u02C8", "\u02C8")
tokenized = tokenized.replace(" \u02CC", "\u02CC")
@ -1888,11 +1892,14 @@ def get_tokenizer_jsons( dir="./models/tokenizers/" ):
additionals = sorted([ f'{dir}/{d}' for d in os.listdir(dir) if d[-5:] == ".json" ]) if os.path.isdir(dir) else []
return relative_paths([ "./modules/tortoise-tts/tortoise/data/tokenizer.json" ] + additionals)
def tokenize_text( text, stringed=True, skip_specials=False ):
def tokenize_text( text, config=None, stringed=True, skip_specials=False ):
from tortoise.utils.tokenizer import VoiceBpeTokenizer
if not config:
config = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0]
if not tts:
tokenizer = VoiceBpeTokenizer(args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0])
tokenizer = VoiceBpeTokenizer(config)
else:
tokenizer = tts.tokenizer