This commit is contained in:
ben_mkiv 2023-08-22 21:00:06 +02:00
commit 1ec3344999
4 changed files with 216 additions and 123 deletions

View File

@ -1,8 +1,8 @@
# AI Voice Cloning # AI Voice Cloning
This [repo](https://git.ecker.tech/mrq/ai-voice-cloning)/[rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows/Linux, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts). > **Note** This project has been in dire need of being rewritten from the ground up for some time. Apologies for any crust from my rather spaghetti code.
Similar to my own findings for Stable Diffusion image generation, this rentry may appear a little disheveled as I note my new findings with TorToiSe. Please keep this in mind if the guide seems to shift a bit or sound confusing. This [repo](https://git.ecker.tech/mrq/ai-voice-cloning)/[rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows/Linux, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts).
>\>Ugh... why bother when I can just abuse 11.AI? >\>Ugh... why bother when I can just abuse 11.AI?

@ -1 +1 @@
Subproject commit 5ff00bf3bfa97e2c8e9f166b920273f83ac9d8f0 Subproject commit b10c58436d6871c26485d30b203e6cfdd4167602

View File

@ -8,3 +8,4 @@ voicefixer
psutil psutil
phonemizer phonemizer
pydantic==1.10.11 pydantic==1.10.11
websockets

View File

@ -45,7 +45,7 @@ from tortoise.utils.device import get_device_name, set_device_name, get_device_c
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"] WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
@ -61,12 +61,15 @@ RESAMPLERS = {}
MIN_TRAINING_DURATION = 0.6 MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669 MAX_TRAINING_DURATION = 11.6097505669
MAX_TRAINING_CHAR_LENGTH = 200
VALLE_ENABLED = False VALLE_ENABLED = False
BARK_ENABLED = False BARK_ENABLED = False
VERBOSE_DEBUG = True VERBOSE_DEBUG = True
import traceback
try: try:
from whisper.normalizers.english import EnglishTextNormalizer from whisper.normalizers.english import EnglishTextNormalizer
from whisper.normalizers.basic import BasicTextNormalizer from whisper.normalizers.basic import BasicTextNormalizer
@ -75,7 +78,7 @@ try:
print("Whisper detected") print("Whisper detected")
except Exception as e: except Exception as e:
if VERBOSE_DEBUG: if VERBOSE_DEBUG:
print("Error:", e) print(traceback.format_exc())
pass pass
try: try:
@ -90,12 +93,14 @@ try:
VALLE_ENABLED = True VALLE_ENABLED = True
except Exception as e: except Exception as e:
if VERBOSE_DEBUG: if VERBOSE_DEBUG:
print("Error:", e) print(traceback.format_exc())
pass pass
if VALLE_ENABLED: if VALLE_ENABLED:
TTSES.append('vall-e') TTSES.append('vall-e')
# torchaudio.set_audio_backend('soundfile')
try: try:
import bark import bark
from bark import text_to_semantic from bark import text_to_semantic
@ -109,35 +114,10 @@ try:
BARK_ENABLED = True BARK_ENABLED = True
except Exception as e: except Exception as e:
if VERBOSE_DEBUG: if VERBOSE_DEBUG:
print("Error:", e) print(traceback.format_exc())
pass pass
if BARK_ENABLED: if BARK_ENABLED:
try:
from vocos import Vocos
VOCOS_ENABLED = True
print("Vocos detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
VOCOS_ENABLED = False
try:
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
HUBERT_ENABLED = True
print("HuBERT detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
HUBERT_ENABLED = False
TTSES.append('bark') TTSES.append('bark')
def semantic_to_audio_tokens( def semantic_to_audio_tokens(
@ -181,7 +161,32 @@ if BARK_ENABLED:
self.device = get_device_name() self.device = get_device_name()
if VOCOS_ENABLED: try:
from vocos import Vocos
self.vocos_enabled = True
print("Vocos detected")
except Exception as e:
if VERBOSE_DEBUG:
print(traceback.format_exc())
self.vocos_enabled = False
try:
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
self.hubert_enabled = True
print("HuBERT detected")
except Exception as e:
if VERBOSE_DEBUG:
print(traceback.format_exc())
self.hubert_enabled = False
if self.vocos_enabled:
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device) self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
def create_voice( self, voice ): def create_voice( self, voice ):
@ -238,7 +243,7 @@ if BARK_ENABLED:
# generate semantic tokens # generate semantic tokens
if HUBERT_ENABLED: if self.hubert_enabled:
wav = wav.to(self.device) wav = wav.to(self.device)
# Extract discrete codes from EnCodec # Extract discrete codes from EnCodec
@ -426,7 +431,7 @@ def generate_bark(**kwargs):
idx_cache = {} idx_cache = {}
for i, file in enumerate(os.listdir(outdir)): for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file) filename = os.path.basename(file)
extension = os.path.splitext(filename)[1] extension = os.path.splitext(filename)[-1][1:]
if extension != ".json" and extension != ".wav": if extension != ".json" and extension != ".wav":
continue continue
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename) match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
@ -672,18 +677,23 @@ def generate_valle(**kwargs):
voice_cache = {} voice_cache = {}
def fetch_voice( voice ): def fetch_voice( voice ):
if voice in voice_cache:
return voice_cache[voice]
voice_dir = f'./training/{voice}/audio/' voice_dir = f'./training/{voice}/audio/'
if not os.path.isdir(voice_dir):
if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0:
voice_dir = f'./voices/{voice}/' voice_dir = f'./voices/{voice}/'
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ] files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
# return files # return files
return random.choice(files) voice_cache[voice] = random.choice(files)
return voice_cache[voice]
def get_settings( override=None ): def get_settings( override=None ):
settings = { settings = {
'ar_temp': float(parameters['temperature']), 'ar_temp': float(parameters['temperature']),
'nar_temp': float(parameters['temperature']), 'nar_temp': float(parameters['temperature']),
'max_ar_samples': parameters['num_autoregressive_samples'], 'max_ar_steps': parameters['num_autoregressive_samples'],
} }
# could be better to just do a ternary on everything above, but i am not a professional # could be better to just do a ternary on everything above, but i am not a professional
@ -697,7 +707,7 @@ def generate_valle(**kwargs):
continue continue
settings[k] = override[k] settings[k] = override[k]
settings['reference'] = fetch_voice(voice=selected_voice) settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ]
return settings return settings
if not parameters['delimiter']: if not parameters['delimiter']:
@ -723,7 +733,7 @@ def generate_valle(**kwargs):
idx_cache = {} idx_cache = {}
for i, file in enumerate(os.listdir(outdir)): for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file) filename = os.path.basename(file)
extension = os.path.splitext(filename)[1] extension = os.path.splitext(filename)[-1][1:]
if extension != ".json" and extension != ".wav": if extension != ".json" and extension != ".wav":
continue continue
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename) match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
@ -783,11 +793,14 @@ def generate_valle(**kwargs):
except Exception as e: except Exception as e:
raise Exception("Prompt settings editing requested, but received invalid JSON") raise Exception("Prompt settings editing requested, but received invalid JSON")
settings = get_settings( override=override ) name = get_name(line=line, candidate=0)
reference = settings['reference']
settings.pop("reference")
gen = tts.inference(cut_text, reference, **settings ) settings = get_settings( override=override )
references = settings['references']
settings.pop("references")
settings['out_path'] = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
gen = tts.inference(cut_text, references, **settings )
run_time = time.time()-start_time run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds") print(f"Generating line took {run_time} seconds")
@ -805,7 +818,7 @@ def generate_valle(**kwargs):
# save here in case some error happens mid-batch # save here in case some error happens mid-batch
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr) #torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr) #soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav') wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
audio_cache[name] = { audio_cache[name] = {
@ -1085,7 +1098,7 @@ def generate_tortoise(**kwargs):
idx_cache = {} idx_cache = {}
for i, file in enumerate(os.listdir(outdir)): for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file) filename = os.path.basename(file)
extension = os.path.splitext(filename)[1] extension = os.path.splitext(filename)[-1][1:]
if extension != ".json" and extension != ".wav": if extension != ".json" and extension != ".wav":
continue continue
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename) match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
@ -1605,30 +1618,18 @@ class TrainingState():
if args.tts_backend == "vall-e": if args.tts_backend == "vall-e":
keys['lrs'] = [ keys['lrs'] = [
'ar.lr', 'nar.lr', 'ar.lr', 'nar.lr',
'ar-half.lr', 'nar-half.lr',
'ar-quarter.lr', 'nar-quarter.lr',
] ]
keys['losses'] = [ keys['losses'] = [
'ar.loss', 'nar.loss', 'ar+nar.loss', # 'ar.loss', 'nar.loss', 'ar+nar.loss',
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
'ar.loss.nll', 'nar.loss.nll', 'ar.loss.nll', 'nar.loss.nll',
'ar-half.loss.nll', 'nar-half.loss.nll',
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
] ]
keys['accuracies'] = [ keys['accuracies'] = [
'ar.loss.acc', 'nar.loss.acc', 'ar.loss.acc', 'nar.loss.acc',
'ar-half.loss.acc', 'nar-half.loss.acc', 'ar.stats.acc', 'nar.loss.acc',
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
] ]
keys['precisions'] = [ keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ]
'ar.loss.precision', 'nar.loss.precision', keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm']
'ar-half.loss.precision', 'nar-half.loss.precision',
'ar-quarter.loss.precision', 'nar-quarter.loss.precision',
]
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
for k in keys['lrs']: for k in keys['lrs']:
if k not in self.info: if k not in self.info:
@ -1746,7 +1747,8 @@ class TrainingState():
if args.tts_backend == "tortoise": if args.tts_backend == "tortoise":
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ]) logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
else: else:
logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ]) log_dir = "logs"
logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/{log_dir}/') ])
if update: if update:
logs = [logs[-1]] logs = [logs[-1]]
@ -2220,6 +2222,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json' infile = f'{indir}/whisper.json'
quantize_in_memory = args.tts_backend == "vall-e"
os.makedirs(f'{indir}/audio/', exist_ok=True) os.makedirs(f'{indir}/audio/', exist_ok=True)
TARGET_SAMPLE_RATE = 22050 TARGET_SAMPLE_RATE = 22050
@ -2245,13 +2249,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
continue continue
results[basename] = result results[basename] = result
waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways if not quantize_in_memory:
# this should also "help" increase throughput a bit when filling the dataloaders waveform, sample_rate = torchaudio.load(file)
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE) # resample to the input rate, since it'll get resampled for training anyways
if waveform.shape[0] == 2: # this should also "help" increase throughput a bit when filling the dataloaders
waveform = waveform[:1] waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) if waveform.shape[0] == 2:
waveform = waveform[:1]
try:
kwargs = {}
if basename[-4:] == ".wav":
kwargs['encoding'] = "PCM_S"
kwargs['bits_per_sample'] = 16
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs)
except Exception as e:
print(e)
with open(infile, 'w', encoding="utf-8") as f: with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t')) f.write(json.dumps(results, indent='\t'))
@ -2317,6 +2332,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
segments = 0 segments = 0
for filename in results: for filename in results:
path = f'./voices/{voice}/{filename}' path = f'./voices/{voice}/{filename}'
extension = os.path.splitext(filename)[-1][1:]
out_extension = extension # "wav"
if not os.path.exists(path): if not os.path.exists(path):
path = f'./training/{voice}/{filename}' path = f'./training/{voice}/{filename}'
@ -2333,7 +2351,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
duration = num_frames / sample_rate duration = num_frames / sample_rate
for segment in result['segments']: for segment in result['segments']:
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence ) sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
if error: if error:
@ -2341,12 +2359,18 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
print(message) print(message)
messages.append(message) messages.append(message)
continue continue
sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE ) sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
if waveform.shape[0] == 2: if waveform.shape[0] == 2:
waveform = waveform[:1] waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16) kwargs = {}
if file[-4:] == ".wav":
kwargs['encoding'] = "PCM_S"
kwargs['bits_per_sample'] = 16
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs)
segments +=1 segments +=1
@ -2462,18 +2486,32 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
errored = 0 errored = 0
messages = [] messages = []
normalize = True normalize = False # True
phonemize = should_phonemize() phonemize = should_phonemize()
lines = { 'training': [], 'validation': [] } lines = { 'training': [], 'validation': [] }
segments = {} segments = {}
quantize_in_memory = args.tts_backend == "vall-e"
if args.tts_backend != "tortoise": if args.tts_backend != "tortoise":
text_length = 0 text_length = 0
audio_length = 0 audio_length = 0
start_offset = -0.1
end_offset = 0.1
trim_silence = False
TARGET_SAMPLE_RATE = 22050
if args.tts_backend != "tortoise":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
for filename in tqdm(results, desc="Parsing results"): for filename in tqdm(results, desc="Parsing results"):
use_segment = use_segments use_segment = use_segments
extension = os.path.splitext(filename)[-1][1:]
out_extension = extension # "wav"
result = results[filename] result = results[filename]
lang = result['language'] lang = result['language']
language = LANGUAGES[lang] if lang in LANGUAGES else lang language = LANGUAGES[lang] if lang in LANGUAGES else lang
@ -2481,8 +2519,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
# check if unsegmented text exceeds 200 characters # check if unsegmented text exceeds 200 characters
if not use_segment: if not use_segment:
if len(result['text']) > 200: if len(result['text']) > MAX_TRAINING_CHAR_LENGTH:
message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}" message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(result['text'])}), using segments: {filename}"
print(message) print(message)
messages.append(message) messages.append(message)
use_segment = True use_segment = True
@ -2490,13 +2528,15 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
# check if unsegmented audio exceeds 11.6s # check if unsegmented audio exceeds 11.6s
if not use_segment: if not use_segment:
path = f'{indir}/audio/{filename}' path = f'{indir}/audio/{filename}'
if not os.path.exists(path): if not quantize_in_memory and not os.path.exists(path):
messages.append(f"Missing source audio: {filename}") messages.append(f"Missing source audio: {filename}")
errored += 1 errored += 1
continue continue
metadata = torchaudio.info(path) duration = 0
duration = metadata.num_frames / metadata.sample_rate for segment in result['segments']:
duration = max(duration, result['segments'][segment]['end'])
if duration >= MAX_TRAINING_DURATION: if duration >= MAX_TRAINING_DURATION:
message = f"Audio too large, using segments: {filename}" message = f"Audio too large, using segments: {filename}"
print(message) print(message)
@ -2511,13 +2551,13 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration: if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
continue continue
path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") path = f'{indir}/audio/' + filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
if os.path.exists(path): if os.path.exists(path):
continue continue
exists = False exists = False
break break
if not exists: if not quantize_in_memory and not exists:
tmp = {} tmp = {}
tmp[filename] = result tmp[filename] = result
print(f"Audio not segmented, segmenting: {filename}") print(f"Audio not segmented, segmenting: {filename}")
@ -2525,6 +2565,23 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
print(message) print(message)
messages = messages + message.split("\n") messages = messages + message.split("\n")
waveform = None
if quantize_in_memory:
path = f'{indir}/audio/{filename}'
if not os.path.exists(path):
path = f'./voices/{voice}/{filename}'
if not os.path.exists(path):
message = f"Audio not found: {path}"
print(message)
messages.append(message)
#continue
else:
waveform = torchaudio.load(path)
waveform = resample(waveform[0], waveform[1], TARGET_SAMPLE_RATE)
if not use_segment: if not use_segment:
segments[filename] = { segments[filename] = {
'text': result['text'], 'text': result['text'],
@ -2533,13 +2590,18 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
'normalizer': normalizer, 'normalizer': normalizer,
'phonemes': result['phonemes'] if 'phonemes' in result else None 'phonemes': result['phonemes'] if 'phonemes' in result else None
} }
if waveform:
segments[filename]['waveform'] = waveform
else: else:
for segment in result['segments']: for segment in result['segments']:
duration = segment['end'] - segment['start'] duration = segment['end'] - segment['start']
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration: if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
continue continue
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = { file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
segments[file] = {
'text': segment['text'], 'text': segment['text'],
'lang': lang, 'lang': lang,
'language': language, 'language': language,
@ -2547,22 +2609,27 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
'phonemes': segment['phonemes'] if 'phonemes' in segment else None 'phonemes': segment['phonemes'] if 'phonemes' in segment else None
} }
if waveform:
sliced, error = slice_waveform( waveform[0], waveform[1], segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
segments[file]['error'] = error
#continue
else:
segments[file]['waveform'] = (sliced, waveform[1])
jobs = { jobs = {
'quantize': [[], []], 'quantize': [[], []],
'phonemize': [[], []], 'phonemize': [[], []],
} }
for file in tqdm(segments, desc="Parsing segments"): for file in tqdm(segments, desc="Parsing segments"):
extension = os.path.splitext(file)[-1][1:]
result = segments[file] result = segments[file]
path = f'{indir}/audio/{file}' path = f'{indir}/audio/{file}'
if not os.path.exists(path):
message = f"Missing segment, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
text = result['text'] text = result['text']
lang = result['lang'] lang = result['lang']
language = result['language'] language = result['language']
@ -2573,28 +2640,20 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
normalized = normalizer(text) if normalize else text normalized = normalizer(text) if normalize else text
if len(text) > 200: if len(text) > MAX_TRAINING_CHAR_LENGTH:
message = f"Text length too long (200 < {len(text)}), skipping... {file}" message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(text)}), skipping... {file}"
print(message) print(message)
messages.append(message) messages.append(message)
errored += 1 errored += 1
continue continue
waveform, sample_rate = torchaudio.load(path) # num_channels, num_frames = waveform.shape
num_channels, num_frames = waveform.shape #duration = num_frames / sample_rate
duration = num_frames / sample_rate
error = validate_waveform( waveform, sample_rate )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
culled = len(text) < text_length culled = len(text) < text_length
if not culled and audio_length > 0: #if not culled and audio_length > 0:
culled = duration < audio_length # culled = duration < audio_length
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}' line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
@ -2605,17 +2664,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
os.makedirs(f'{indir}/valle/', exist_ok=True) os.makedirs(f'{indir}/valle/', exist_ok=True)
qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}' #phn_file = f'{indir}/valle/{file.replace(f".{extension}",".phn.txt")}'
if not os.path.exists(qnt_file): phn_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".phn.txt")}'
jobs['quantize'][0].append(qnt_file)
jobs['quantize'][1].append((waveform, sample_rate))
"""
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
"""
phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
if not os.path.exists(phn_file): if not os.path.exists(phn_file):
jobs['phonemize'][0].append(phn_file) jobs['phonemize'][0].append(phn_file)
jobs['phonemize'][1].append(normalized) jobs['phonemize'][1].append(normalized)
@ -2625,13 +2675,46 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
print("Phonemized:", file, normalized, text) print("Phonemized:", file, normalized, text)
""" """
#qnt_file = f'{indir}/valle/{file.replace(f".{extension}",".qnt.pt")}'
qnt_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".qnt.pt")}'
if 'error' not in result:
if not quantize_in_memory and not os.path.exists(path):
message = f"Missing segment, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
if not os.path.exists(qnt_file):
waveform = None
if 'waveform' in result:
waveform, sample_rate = result['waveform']
elif os.path.exists(path):
waveform, sample_rate = torchaudio.load(path)
error = validate_waveform( waveform, sample_rate )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
if waveform is not None:
jobs['quantize'][0].append(qnt_file)
jobs['quantize'][1].append((waveform, sample_rate))
"""
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
"""
for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"): for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"):
qnt_file = jobs['quantize'][0][i] qnt_file = jobs['quantize'][0][i]
waveform, sample_rate = jobs['quantize'][1][i] waveform, sample_rate = jobs['quantize'][1][i]
quantized = valle_quantize( waveform, sample_rate ).cpu() quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, qnt_file) torch.save(quantized, qnt_file)
print("Quantized:", qnt_file) #print("Quantized:", qnt_file)
for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"): for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"):
phn_file = jobs['phonemize'][0][i] phn_file = jobs['phonemize'][0][i]
@ -2640,7 +2723,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
try: try:
phonemized = valle_phonemize( normalized ) phonemized = valle_phonemize( normalized )
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized)) open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
print("Phonemized:", phn_file) #print("Phonemized:", phn_file)
except Exception as e: except Exception as e:
message = f"Failed to phonemize: {phn_file}: {normalized}" message = f"Failed to phonemize: {phn_file}: {normalized}"
messages.append(message) messages.append(message)
@ -2970,17 +3053,26 @@ def import_voices(files, saveAs=None, progress=None):
def relative_paths( dirs ): def relative_paths( dirs ):
return [ './' + os.path.relpath( d ).replace("\\", "/") for d in dirs ] return [ './' + os.path.relpath( d ).replace("\\", "/") for d in dirs ]
def get_voice( name, dir=get_voice_dir(), load_latents=True ): def get_voice( name, dir=get_voice_dir(), load_latents=True, extensions=["wav", "mp3", "flac"] ):
subj = f'{dir}/{name}/' subj = f'{dir}/{name}/'
if not os.path.isdir(subj): if not os.path.isdir(subj):
return return
files = os.listdir(subj)
voice = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.flac'))
if load_latents: if load_latents:
voice = voice + list(glob(f'{subj}/*.pth')) extensions.append("pth")
voice = []
for file in files:
ext = os.path.splitext(file)[-1][1:]
if ext not in extensions:
continue
voice.append(f'{subj}/{file}')
return sorted( voice ) return sorted( voice )
def get_voice_list(dir=get_voice_dir(), append_defaults=False): def get_voice_list(dir=get_voice_dir(), append_defaults=False, extensions=["wav", "mp3", "flac", "pth"]):
defaults = [ "random", "microphone" ] defaults = [ "random", "microphone" ]
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
#res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) #res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
@ -2993,7 +3085,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
continue continue
if len(os.listdir(os.path.join(dir, name))) == 0: if len(os.listdir(os.path.join(dir, name))) == 0:
continue continue
files = get_voice( name, dir=dir ) files = get_voice( name, dir=dir, extensions=extensions )
if len(files) > 0: if len(files) > 0:
res.append(name) res.append(name)
@ -3001,7 +3093,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
for subdir in os.listdir(f'{dir}/{name}'): for subdir in os.listdir(f'{dir}/{name}'):
if not os.path.isdir(f'{dir}/{name}/{subdir}'): if not os.path.isdir(f'{dir}/{name}/{subdir}'):
continue continue
files = get_voice( f'{name}/{subdir}', dir=dir ) files = get_voice( f'{name}/{subdir}', dir=dir, extensions=extensions )
if len(files) == 0: if len(files) == 0:
continue continue
res.append(f'{name}/{subdir}') res.append(f'{name}/{subdir}')