forked from mrq/ai-voice-cloning
preparations for training an IPA-based finetune
This commit is contained in:
parent 7b80f7a42f
commit ee8270bdfb
121  models/tokenizers/ipa.json  Executable file
@@ -0,0 +1,121 @@
{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {
      "id": 0,
      "special": true,
      "content": "[STOP]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false
    },
    {
      "id": 1,
      "special": true,
      "content": "[UNK]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false
    },
    {
      "id": 2,
      "special": true,
      "content": "[SPACE]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false
    }
  ],
  "normalizer": null,
  "pre_tokenizer": null,
  "post_processor": null,
  "decoder": null,
  "model": {
    "type": "BPE",
    "dropout": null,
    "unk_token": "[UNK]",
    "continuing_subword_prefix": null,
    "end_of_word_suffix": null,
    "fuse_unk": false,
    "vocab": {
      "[STOP]": 0,
      "[UNK]": 1,
      "[SPACE]": 2,
      "!": 3,
      "'": 4,
      "(": 5,
      ")": 6,
      ",": 7,
      "-": 8,
      ".": 9,
      "/": 10,
      ":": 11,
      ";": 12,
      "?": 13,
      "a": 14,
      "aɪ": 15,
      "aʊ": 16,
      "b": 17,
      "d": 18,
      "d͡": 19,
      "d͡ʒ": 20,
      "e": 21,
      "eɪ": 22,
      "f": 23,
      "h": 24,
      "i": 25,
      "j": 26,
      "k": 27,
      "l": 28,
      "m": 29,
      "n": 30,
      "o": 31,
      "oʊ": 32,
      "p": 33,
      "s": 34,
      "t": 35,
      "t͡": 36,
      "t͡ʃ": 37,
      "u": 38,
      "v": 39,
      "w": 40,
      "z": 41,
      "|": 42,
      "æ": 43,
      "ð": 44,
      "ŋ": 45,
      "ɑ": 46,
      "ɔ": 47,
      "ɔɪ": 48,
      "ə": 49,
      "ɚ": 50,
      "ɛ": 51,
      "ɡ": 52,
      "ɪ": 53,
      "ɹ": 54,
      "ʃ": 55,
      "ʊ": 56,
      "ʌ": 57,
      "ʒ": 58,
      "θ": 59
    },
    "merges": [
      "a ɪ",
      "a ʊ",
      "d͡ ʒ",
      "e ɪ",
      "o ʊ",
      "t͡ ʃ",
      "ɔ ɪ"
    ]
  }
}
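The new tokenizer defines a 60-symbol vocabulary of IPA phonemes plus [STOP]/[UNK]/[SPACE] specials, with BPE merges only for the seven diphthongs and affricates. A minimal sketch (mine, not part of the commit) of loading it with the HuggingFace tokenizers library; the sample string and expected split are illustrative:

from tokenizers import Tokenizer

tok = Tokenizer.from_file("models/tokenizers/ipa.json")
# No normalizer or pre-tokenizer is configured, so input is expected to
# already be an IPA string; the BPE merges recombine multi-codepoint symbols.
enc = tok.encode("haʊ")
print(enc.tokens)  # e.g. ['h', 'aʊ'] via the "a ʊ" merge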

@@ -1 +1 @@
-Subproject commit b253da6e353f0170c3eb60fe299c41d2fa21db50
+Subproject commit 730a04708d2cb29f526c3397894950a2733e6e29

@@ -1 +1 @@
-Subproject commit 42cb1f36741aa3a24e7aab03e73b51becd182fa7
+Subproject commit 99618694db4cd7b77e68b62753bb8e2418ac0d55

67  src/utils.py
@@ -20,8 +20,7 @@ import subprocess
 import psutil
 import yaml
 import hashlib
-import io
-import gzip
+import string
 
 import tqdm
 import torch
@@ -40,6 +39,13 @@ from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
 
 from whisper.normalizers.english import EnglishTextNormalizer
+from whisper.normalizers.basic import BasicTextNormalizer
+from whisper.tokenizer import LANGUAGES
+
+try:
+    from phonemizer import phonemize as phonemizer
+except Exception as e:
+    pass
 
 MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
 
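The optional import above wraps the `phonemizer` package, which fronts an espeak-ng backend for grapheme-to-IPA conversion, hence the try/except guard. A hedged sketch of how that call typically looks (backend availability and exact output are assumptions):

from phonemizer import phonemize as phonemizer

# Requires espeak/espeak-ng installed on the system.
ipa = phonemizer("Hello world.", language="en-us",
                 preserve_punctuation=True, strip=True)
print(ipa)  # e.g. "həloʊ wɜːld." (exact symbols vary by espeak version)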
@@ -64,7 +70,7 @@ VALLE_ENABLED = False
 
 try:
     from vall_e.emb.qnt import encode as quantize
-    from vall_e.emb.g2p import encode as phonemize
+    # from vall_e.emb.g2p import encode as phonemize
 
     VALLE_ENABLED = True
 except Exception as e:
@@ -1157,7 +1163,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
     if whisper_model is None:
         load_whisper_model(language=language)
 
-
     results = {}
 
     files = sorted( get_voices(load_latents=False)[voice] )
@@ -1175,14 +1180,15 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
         if basename in results and skip_existings:
             print(f"Skipping already parsed file: {basename}")
         else:
-            results[basename] = whisper_transcribe(file, language=language)
+            result = whisper_transcribe(file, language=language)
+            results[basename] = result
 
         waveform, sample_rate = torchaudio.load(file)
         # resample to the input rate, since it'll get resampled for training anyways
         # this should also "help" increase throughput a bit when filling the dataloaders
         waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050)
 
-        torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
+        torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
 
         with open(infile, 'w', encoding="utf-8") as f:
             f.write(json.dumps(results, indent='\t'))
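Both save sites in this commit pin the output format. A small standalone sketch of the change (filename and tensor are illustrative): with a float32 tensor and no encoding arguments, torchaudio writes float WAVs, so pinning signed 16-bit PCM roughly halves file size and standardizes the dataset audio.

import torch
import torchaudio

waveform = torch.zeros(1, 22050)  # hypothetical one-second mono clip
torchaudio.save("clip.wav", waveform, 22050, encoding="PCM_S", bits_per_sample=16)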
@@ -1248,18 +1254,28 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None ):
             messages.append(message)
             continue
         sliced, _ = resample( sliced, sample_rate, 22050 )
-        torchaudio.save(f"{indir}/audio/{file}", sliced, 22050)
+        torchaudio.save(f"{indir}/audio/{file}", sliced, 22050, encoding="PCM_S", bits_per_sample=16)
 
         segments += 1
 
     messages.append(f"Sliced segments: {files} => {segments}.")
     return "\n".join(messages)
 
-def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
+"""
+def phonemizer( text, language="eng" ):
+    transducer = make_g2p(language, f'{language}-ipa')
+    phones = transducer(text).output_string
+    ignored = [" "] + [ p for p in string.punctuation ]
+    return ["_" if p in ignored else p for p in phones]
+"""
+
+def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, normalize=True ):
     indir = f'./training/{voice}/'
     infile = f'{indir}/whisper.json'
     messages = []
 
+    phonemize = args.tokenizer_json[-8:] == "ipa.json"
+
     if not os.path.exists(infile):
         raise Exception(f"Missing dataset: {infile}")
 
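The quoted-out helper above sketches an alternative phonemizer built on the `g2p` package (the same make_g2p call it references) instead of espeak. For reference, a hedged example of that API, using the commit's own "eng"/"eng-ipa" codes; the printed output is an assumption:

from g2p import make_g2p

# Build an English-orthography -> English-IPA transducer.
transducer = make_g2p("eng", "eng-ipa")
print(transducer("hello").output_string)  # e.g. "hʌloʊ"; varies by g2p version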
@@ -1272,13 +1288,20 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
             'supervisions': [],
         }
 
-    normalizer = EnglishTextNormalizer() if normalize else None
 
     errored = 0
     for filename in results:
-        result = results[filename]
         use_segment = use_segments
 
+        result = results[filename]
+        language = LANGUAGES[result['language']] if result['language'] in LANGUAGES else None
+        if language == "english":
+            language = "en-us"
+
+        normalizer = None
+        if normalize:
+            normalizer = EnglishTextNormalizer() if language.lower()[:2] == "en" else BasicTextNormalizer()
+
         # check if unsegmented text exceeds 200 characters
         if not use_segment:
             if len(result['text']) > 200:
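The normalizer now follows the detected language instead of being fixed to English (note that `language` can still be None here if Whisper reports a code outside LANGUAGES). A minimal standalone sketch of the selection logic; the helper name and sample output are mine:

from whisper.normalizers.basic import BasicTextNormalizer
from whisper.normalizers.english import EnglishTextNormalizer

def pick_normalizer(language):
    # "en-us", "en-gb", ... all begin with "en"
    return EnglishTextNormalizer() if language.lower()[:2] == "en" else BasicTextNormalizer()

print(pick_normalizer("en-us")("It's 2 o'clock!"))  # e.g. "it is 2 o'clock"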
@@ -1325,7 +1348,14 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
                 continue
 
             text = segment['text'].strip()
-            normalized_text = normalizer(text) if normalize and result['language'] == "en" else text
+            normalized_text = normalizer(text) if normalize else None
+            try:
+                phonemes = phonemizer( text, language=language, preserve_punctuation=True, strip=True ) if phonemize else None
+            except Exception as e:
+                pass
+
+            if phonemize and phonemes:
+                text = phonemes
 
             if len(text) > 200:
                 message = f"Text length too long (200 < {len(text)}), skipping... {file}"
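One caveat with the block above, not addressed by the commit: if the very first phonemizer call raises, `phonemes` is never bound, and the following `if phonemize and phonemes` check would itself raise a NameError. A hedged sketch of a safer wrapper (function name is mine; the call mirrors the one in the diff):

from phonemizer import phonemize as phonemizer  # same optional dependency as above

def to_phonemes(text, language="en-us"):
    phonemes = None  # bound up front so a failed call cannot leave it undefined
    try:
        phonemes = phonemizer(text, language=language,
                              preserve_punctuation=True, strip=True)
    except Exception:
        pass  # missing espeak backend, unsupported language, etc.
    return phonemes if phonemes else text  # fall back to the raw transcript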
@@ -1351,11 +1381,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
             if not culled and audio_length > 0:
                 culled = duration < audio_length
 
-            # for when i add in a little treat ;), as it requires normalized text
-            if normalize and len(normalized_text) < 200:
-                line = f'audio/{file}|{text}|{normalized_text}'
-            else:
-                line = f'audio/{file}|{text}'
+            line = f'audio/{file}|{text}'
 
             lines['training' if not culled else 'validation'].append(line)
 
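With the normalized-text column dropped, every dataset line is back to the two-column LJSpeech-style `path|text` form; when phonemization is active, `text` is the IPA string, e.g. `audio/0001.wav|haʊ ɑɹ ju` (filename and phonemes illustrative).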
@@ -1365,7 +1391,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
             os.makedirs(f'{indir}/valle/', exist_ok=True)
 
             from vall_e.emb.qnt import encode as quantize
-            from vall_e.emb.g2p import encode as phonemize
+            # from vall_e.emb.g2p import encode as phonemize
 
             if waveform.shape[0] == 2:
                 waveform = waveform[:1]
@@ -1373,8 +1399,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
             quantized = quantize( waveform, sample_rate ).cpu()
             torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
 
-            phonemes = phonemize(normalized_text)
-            open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
+            # phonemes = phonemizer(normalized_text)
+            open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(text))
 
         training_joined = "\n".join(lines['training'])
         validation_joined = "\n".join(lines['validation'])
@@ -1536,8 +1562,10 @@ def save_training_settings( **kwargs ):
 
     if settings['save_rate'] < 1:
         settings['save_rate'] = 1
+    """
     if settings['validation_rate'] < 1:
         settings['validation_rate'] = 1
+    """
 
     settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
 
@@ -1554,7 +1582,6 @@ def save_training_settings( **kwargs ):
         settings['validation_enabled'] = False
         messages.append("Validation batch size == 0, disabling validation...")
     else:
-        settings['validation_enabled'] = True
         with open(settings['validation_path'], 'r', encoding="utf-8") as f:
             validation_lines = len(f.readlines())
 

src/webui.py

@@ -443,7 +443,7 @@ def setup_gradio():
                 DATASET_SETTINGS['validation_text_length'] = gr.Number(label="Validation Text Length Threshold", value=12, precision=0)
                 DATASET_SETTINGS['validation_audio_length'] = gr.Number(label="Validation Audio Length Threshold", value=1 )
             with gr.Row():
-                DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
+                DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Existing", value=False)
                 DATASET_SETTINGS['slice'] = gr.Checkbox(label="Slice Segments", value=False)
                 DATASET_SETTINGS['trim_silence'] = gr.Checkbox(label="Trim Silence", value=False)
             with gr.Row():
@@ -496,6 +496,7 @@ def setup_gradio():
             with gr.Row():
                 TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
                 TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
+                TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False)
 
             with gr.Row():
                 TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0)