From ee8270bdfb0c9ba98cd6abdd1280cf057972548a Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 16 Mar 2023 04:25:33 +0000
Subject: [PATCH] preparations for training an IPA-based finetune

---
 models/tokenizers/ipa.json | 121 +++++++++++++++++++++++++++++++++++++
 modules/dlas               |   2 +-
 modules/tortoise-tts       |   2 +-
 src/utils.py               |  67 ++++++++++++++------
 src/webui.py               |   3 +-
 5 files changed, 172 insertions(+), 23 deletions(-)
 create mode 100755 models/tokenizers/ipa.json

diff --git a/models/tokenizers/ipa.json b/models/tokenizers/ipa.json
new file mode 100755
index 0000000..4a8d78d
--- /dev/null
+++ b/models/tokenizers/ipa.json
@@ -0,0 +1,121 @@
+{
+	"version": "1.0",
+	"truncation": null,
+	"padding": null,
+	"added_tokens":
+	[
+		{
+			"id": 0,
+			"special": true,
+			"content": "[STOP]",
+			"single_word": false,
+			"lstrip": false,
+			"rstrip": false,
+			"normalized": false
+		},
+		{
+			"id": 1,
+			"special": true,
+			"content": "[UNK]",
+			"single_word": false,
+			"lstrip": false,
+			"rstrip": false,
+			"normalized": false
+		},
+		{
+			"id": 2,
+			"special": true,
+			"content": "[SPACE]",
+			"single_word": false,
+			"lstrip": false,
+			"rstrip": false,
+			"normalized": false
+		}
+	],
+	"normalizer": null,
+	"pre_tokenizer": null,
+	"post_processor": null,
+	"decoder": null,
+	"model":
+	{
+		"type": "BPE",
+		"dropout": null,
+		"unk_token": "[UNK]",
+		"continuing_subword_prefix": null,
+		"end_of_word_suffix": null,
+		"fuse_unk": false,
+		"vocab":
+		{
+			"[STOP]": 0,
+			"[UNK]": 1,
+			"[SPACE]": 2,
+			"!": 3,
+			"'": 4,
+			"(": 5,
+			")": 6,
+			",": 7,
+			"-": 8,
+			".": 9,
+			"/": 10,
+			":": 11,
+			";": 12,
+			"?": 13,
+			"a": 14,
+			"aɪ": 15,
+			"aʊ": 16,
+			"b": 17,
+			"d": 18,
+			"d͡": 19,
+			"d͡ʒ": 20,
+			"e": 21,
+			"eɪ": 22,
+			"f": 23,
+			"h": 24,
+			"i": 25,
+			"j": 26,
+			"k": 27,
+			"l": 28,
+			"m": 29,
+			"n": 30,
+			"o": 31,
+			"oʊ": 32,
+			"p": 33,
+			"s": 34,
+			"t": 35,
+			"t͡": 36,
+			"t͡ʃ": 37,
+			"u": 38,
+			"v": 39,
+			"w": 40,
+			"z": 41,
+			"|": 42,
+			"æ": 43,
+			"ð": 44,
+			"ŋ": 45,
+			"ɑ": 46,
+			"ɔ": 47,
+			"ɔɪ": 48,
+			"ə": 49,
+			"ɚ": 50,
+			"ɛ": 51,
+			"ɡ": 52,
+			"ɪ": 53,
+			"ɹ": 54,
+			"ʃ": 55,
+			"ʊ": 56,
+			"ʌ": 57,
+			"ʒ": 58,
+			"θ": 59
+		},
+		"merges":
+		[
+			"a ɪ",
+			"a ʊ",
+			"d͡ ʒ",
+			"e ɪ",
+			"o ʊ",
+			"t͡ ʃ",
+			"ɔ ɪ"
+		]
+	}
+}
\ No newline at end of file
diff --git a/modules/dlas b/modules/dlas
index b253da6..730a047 160000
--- a/modules/dlas
+++ b/modules/dlas
@@ -1 +1 @@
-Subproject commit b253da6e353f0170c3eb60fe299c41d2fa21db50
+Subproject commit 730a04708d2cb29f526c3397894950a2733e6e29
diff --git a/modules/tortoise-tts b/modules/tortoise-tts
index 42cb1f3..9961869 160000
--- a/modules/tortoise-tts
+++ b/modules/tortoise-tts
@@ -1 +1 @@
-Subproject commit 42cb1f36741aa3a24e7aab03e73b51becd182fa7
+Subproject commit 99618694db4cd7b77e68b62753bb8e2418ac0d55
diff --git a/src/utils.py b/src/utils.py
index 3c375a1..f5fcf5a 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -20,8 +20,7 @@ import subprocess
 import psutil
 import yaml
 import hashlib
-import io
-import gzip
+import string
 
 import tqdm
 import torch
@@ -40,6 +39,13 @@ from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
 
 from whisper.normalizers.english import EnglishTextNormalizer
+from whisper.normalizers.basic import BasicTextNormalizer
+from whisper.tokenizer import LANGUAGES
+
+try:
+	from phonemizer import phonemize as phonemizer
+except Exception as e:
+	pass
 
 MODELS['dvae.pth'] = 
"https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" @@ -64,7 +70,7 @@ VALLE_ENABLED = False try: from vall_e.emb.qnt import encode as quantize - from vall_e.emb.g2p import encode as phonemize + # from vall_e.emb.g2p import encode as phonemize VALLE_ENABLED = True except Exception as e: @@ -1157,7 +1163,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non if whisper_model is None: load_whisper_model(language=language) - results = {} files = sorted( get_voices(load_latents=False)[voice] ) @@ -1175,14 +1180,15 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non if basename in results and skip_existings: print(f"Skipping already parsed file: {basename}") else: - results[basename] = whisper_transcribe(file, language=language) + result = whisper_transcribe(file, language=language) + results[basename] = result waveform, sample_rate = torchaudio.load(file) # resample to the input rate, since it'll get resampled for training anyways # this should also "help" increase throughput a bit when filling the dataloaders waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050) - torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate) + torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) with open(infile, 'w', encoding="utf-8") as f: f.write(json.dumps(results, indent='\t')) @@ -1248,18 +1254,28 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul messages.append(message) continue sliced, _ = resample( sliced, sample_rate, 22050 ) - torchaudio.save(f"{indir}/audio/{file}", sliced, 22050) + torchaudio.save(f"{indir}/audio/{file}", sliced, 22050, encoding="PCM_S", bits_per_sample=16) segments +=1 messages.append(f"Sliced segments: {files} => {segments}.") return "\n".join(messages) -def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ): +""" +def phonemizer( text, language="eng" ): + transducer = make_g2p(language, f'{language}-ipa') + phones = transducer(text).output_string + ignored = [" "] + [ p for p in string.punctuation ] + return ["_" if p in ignored else p for p in phones] +""" + +def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, normalize=True ): indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' messages = [] + phonemize = phonemize=args.tokenizer_json[-8:] == "ipa.json" + if not os.path.exists(infile): raise Exception(f"Missing dataset: {infile}") @@ -1272,12 +1288,19 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T 'supervisions': [], } - normalizer = EnglishTextNormalizer() if normalize else None errored = 0 for filename in results: - result = results[filename] use_segment = use_segments + + result = results[filename] + language = LANGUAGES[result['language']] if result['language'] in LANGUAGES else None + if language == "english": + language = "en-us" + + normalizer = None + if normalize: + normalizer = EnglishTextNormalizer() if language.lower()[:2] == "en" else BasicTextNormalizer() # check if unsegmented text exceeds 200 characters if not use_segment: @@ -1325,7 +1348,14 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T continue text = segment['text'].strip() - normalized_text = normalizer(text) if normalize and result['language'] == "en" else text + normalized_text = 
normalizer(text) if normalize else None
+			try:
+				phonemes = phonemizer( text, language=language, preserve_punctuation=True, strip=True ) if phonemize else None
+			except Exception as e:
+				phonemes = None
+
+			if phonemize and phonemes:
+				text = phonemes
 
 			if len(text) > 200:
 				message = f"Text length too long (200 < {len(text)}), skipping... {file}"
@@ -1351,11 +1381,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
 			if not culled and audio_length > 0:
 				culled = duration < audio_length
 
-			# for when i add in a little treat ;), as it requires normalized text
-			if normalize and len(normalized_text) < 200:
-				line = f'audio/{file}|{text}|{normalized_text}'
-			else:
-				line = f'audio/{file}|{text}'
+			line = f'audio/{file}|{text}'
 
 			lines['training' if not culled else 'validation'].append(line)
@@ -1365,7 +1391,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
 				os.makedirs(f'{indir}/valle/', exist_ok=True)
 
 				from vall_e.emb.qnt import encode as quantize
-				from vall_e.emb.g2p import encode as phonemize
+				# from vall_e.emb.g2p import encode as phonemize
 
 				if waveform.shape[0] == 2:
 					waveform = waveform[:1]
@@ -1373,8 +1399,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
 				quantized = quantize( waveform, sample_rate ).cpu()
 				torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
 
-				phonemes = phonemize(normalized_text)
-				open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
+				# phonemes = phonemizer(normalized_text)
+				open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(text))
 
 	training_joined = "\n".join(lines['training'])
 	validation_joined = "\n".join(lines['validation'])
@@ -1536,8 +1562,10 @@ def save_training_settings( **kwargs ):
 	if settings['save_rate'] < 1:
 		settings['save_rate'] = 1
+	"""
 	if settings['validation_rate'] < 1:
 		settings['validation_rate'] = 1
+	"""
 
 	settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
@@ -1554,7 +1582,6 @@ def save_training_settings( **kwargs ):
 			settings['validation_enabled'] = False
 			messages.append("Validation batch size == 0, disabling validation...")
 		else:
-			settings['validation_enabled'] = True
 			with open(settings['validation_path'], 'r', encoding="utf-8") as f:
 				validation_lines = len(f.readlines())
diff --git a/src/webui.py b/src/webui.py
index 48c16a9..6f6462a 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -443,7 +443,7 @@ def setup_gradio():
 					DATASET_SETTINGS['validation_text_length'] = gr.Number(label="Validation Text Length Threshold", value=12, precision=0)
 					DATASET_SETTINGS['validation_audio_length'] = gr.Number(label="Validation Audio Length Threshold", value=1 )
 				with gr.Row():
-					DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
+					DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Existing", value=False)
 					DATASET_SETTINGS['slice'] = gr.Checkbox(label="Slice Segments", value=False)
 					DATASET_SETTINGS['trim_silence'] = gr.Checkbox(label="Trim Silence", value=False)
 				with gr.Row():
@@ -496,6 +496,7 @@ def setup_gradio():
 						with gr.Row():
 							TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
 							TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
+							TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False)
 						with gr.Row():
 							TRAINING_SETTINGS["workers"] = 
gr.Number(label="Worker Processes", value=2, precision=0)