diff --git a/README.md b/README.md index b34fd68..f04b5c4 100755 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # AI Voice Cloning -This [repo](https://git.ecker.tech/mrq/ai-voice-cloning)/[rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows/Linux, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts). +> **Note** This project has been in dire need of being rewritten from the ground up for some time. Apologies for any crust from my rather spaghetti code. -Similar to my own findings for Stable Diffusion image generation, this rentry may appear a little disheveled as I note my new findings with TorToiSe. Please keep this in mind if the guide seems to shift a bit or sound confusing. +This [repo](https://git.ecker.tech/mrq/ai-voice-cloning)/[rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows/Linux, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts). >\>Ugh... why bother when I can just abuse 11.AI? diff --git a/modules/tortoise-tts b/modules/tortoise-tts index 5ff00bf..b10c584 160000 --- a/modules/tortoise-tts +++ b/modules/tortoise-tts @@ -1 +1 @@ -Subproject commit 5ff00bf3bfa97e2c8e9f166b920273f83ac9d8f0 +Subproject commit b10c58436d6871c26485d30b203e6cfdd4167602 diff --git a/requirements.txt b/requirements.txt index e1794a4..fcb0746 100755 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ music-tag voicefixer psutil phonemizer -pydantic==1.10.11 \ No newline at end of file +pydantic==1.10.11 +websockets \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 759e054..1eaefb7 100755 --- a/src/utils.py +++ b/src/utils.py @@ -45,7 +45,7 @@ from tortoise.utils.device import get_device_name, set_device_name, get_device_c MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" -WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] +WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] @@ -61,12 +61,15 @@ RESAMPLERS = {} MIN_TRAINING_DURATION = 0.6 MAX_TRAINING_DURATION = 11.6097505669 +MAX_TRAINING_CHAR_LENGTH = 200 VALLE_ENABLED = False BARK_ENABLED = False VERBOSE_DEBUG = True +import traceback + try: from whisper.normalizers.english import EnglishTextNormalizer from whisper.normalizers.basic import BasicTextNormalizer @@ -75,7 +78,7 @@ try: print("Whisper detected") except Exception as e: if VERBOSE_DEBUG: - print("Error:", e) + print(traceback.format_exc()) pass try: @@ -90,12 +93,14 @@ try: VALLE_ENABLED = True except Exception as e: if VERBOSE_DEBUG: - print("Error:", e) + print(traceback.format_exc()) pass if VALLE_ENABLED: TTSES.append('vall-e') +# torchaudio.set_audio_backend('soundfile') + try: import bark from bark import text_to_semantic @@ -109,35 +114,10 @@ try: BARK_ENABLED = True except Exception as e: if VERBOSE_DEBUG: - print("Error:", e) + print(traceback.format_exc()) pass if BARK_ENABLED: - try: - from vocos import Vocos - VOCOS_ENABLED = True - print("Vocos detected") - except Exception as e: - if VERBOSE_DEBUG: - print("Error:", e) - VOCOS_ENABLED = False - - try: - from hubert.hubert_manager import HuBERTManager - from hubert.pre_kmeans_hubert import CustomHubert - from hubert.customtokenizer import CustomTokenizer - - hubert_manager = HuBERTManager() - hubert_manager.make_sure_hubert_installed() - hubert_manager.make_sure_tokenizer_installed() - - HUBERT_ENABLED = True - print("HuBERT detected") - except Exception as e: - if VERBOSE_DEBUG: - print("Error:", e) - HUBERT_ENABLED = False - TTSES.append('bark') def semantic_to_audio_tokens( @@ -181,7 +161,32 @@ if BARK_ENABLED: self.device = get_device_name() - if VOCOS_ENABLED: + try: + from vocos import Vocos + self.vocos_enabled = True + print("Vocos detected") + except Exception as e: + if VERBOSE_DEBUG: + print(traceback.format_exc()) + self.vocos_enabled = False + + try: + from hubert.hubert_manager import HuBERTManager + from hubert.pre_kmeans_hubert import CustomHubert + from hubert.customtokenizer import CustomTokenizer + + hubert_manager = HuBERTManager() + hubert_manager.make_sure_hubert_installed() + hubert_manager.make_sure_tokenizer_installed() + + self.hubert_enabled = True + print("HuBERT detected") + except Exception as e: + if VERBOSE_DEBUG: + print(traceback.format_exc()) + self.hubert_enabled = False + + if self.vocos_enabled: self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device) def create_voice( self, voice ): @@ -238,7 +243,7 @@ if BARK_ENABLED: # generate semantic tokens - if HUBERT_ENABLED: + if self.hubert_enabled: wav = wav.to(self.device) # Extract discrete codes from EnCodec @@ -426,7 +431,7 @@ def generate_bark(**kwargs): idx_cache = {} for i, file in enumerate(os.listdir(outdir)): filename = os.path.basename(file) - extension = os.path.splitext(filename)[1] + extension = os.path.splitext(filename)[-1][1:] if extension != ".json" and extension != ".wav": continue match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename) @@ -672,18 +677,23 @@ def generate_valle(**kwargs): voice_cache = {} def fetch_voice( voice ): + if voice in voice_cache: + return voice_cache[voice] voice_dir = f'./training/{voice}/audio/' - if not os.path.isdir(voice_dir): + + if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0: voice_dir = f'./voices/{voice}/' + files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ] # return files - return random.choice(files) + voice_cache[voice] = random.choice(files) + return voice_cache[voice] def get_settings( override=None ): settings = { 'ar_temp': float(parameters['temperature']), 'nar_temp': float(parameters['temperature']), - 'max_ar_samples': parameters['num_autoregressive_samples'], + 'max_ar_steps': parameters['num_autoregressive_samples'], } # could be better to just do a ternary on everything above, but i am not a professional @@ -697,7 +707,7 @@ def generate_valle(**kwargs): continue settings[k] = override[k] - settings['reference'] = fetch_voice(voice=selected_voice) + settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ] return settings if not parameters['delimiter']: @@ -723,7 +733,7 @@ def generate_valle(**kwargs): idx_cache = {} for i, file in enumerate(os.listdir(outdir)): filename = os.path.basename(file) - extension = os.path.splitext(filename)[1] + extension = os.path.splitext(filename)[-1][1:] if extension != ".json" and extension != ".wav": continue match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename) @@ -783,11 +793,14 @@ def generate_valle(**kwargs): except Exception as e: raise Exception("Prompt settings editing requested, but received invalid JSON") - settings = get_settings( override=override ) - reference = settings['reference'] - settings.pop("reference") + name = get_name(line=line, candidate=0) - gen = tts.inference(cut_text, reference, **settings ) + settings = get_settings( override=override ) + references = settings['references'] + settings.pop("references") + settings['out_path'] = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav' + + gen = tts.inference(cut_text, references, **settings ) run_time = time.time()-start_time print(f"Generating line took {run_time} seconds") @@ -805,7 +818,7 @@ def generate_valle(**kwargs): # save here in case some error happens mid-batch #torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr) - soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr) + #soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr) wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav') audio_cache[name] = { @@ -1085,7 +1098,7 @@ def generate_tortoise(**kwargs): idx_cache = {} for i, file in enumerate(os.listdir(outdir)): filename = os.path.basename(file) - extension = os.path.splitext(filename)[1] + extension = os.path.splitext(filename)[-1][1:] if extension != ".json" and extension != ".wav": continue match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename) @@ -1605,30 +1618,18 @@ class TrainingState(): if args.tts_backend == "vall-e": keys['lrs'] = [ 'ar.lr', 'nar.lr', - 'ar-half.lr', 'nar-half.lr', - 'ar-quarter.lr', 'nar-quarter.lr', ] keys['losses'] = [ - 'ar.loss', 'nar.loss', 'ar+nar.loss', - 'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss', - 'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss', - + # 'ar.loss', 'nar.loss', 'ar+nar.loss', 'ar.loss.nll', 'nar.loss.nll', - 'ar-half.loss.nll', 'nar-half.loss.nll', - 'ar-quarter.loss.nll', 'nar-quarter.loss.nll', ] keys['accuracies'] = [ 'ar.loss.acc', 'nar.loss.acc', - 'ar-half.loss.acc', 'nar-half.loss.acc', - 'ar-quarter.loss.acc', 'nar-quarter.loss.acc', + 'ar.stats.acc', 'nar.loss.acc', ] - keys['precisions'] = [ - 'ar.loss.precision', 'nar.loss.precision', - 'ar-half.loss.precision', 'nar-half.loss.precision', - 'ar-quarter.loss.precision', 'nar-quarter.loss.precision', - ] - keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm'] + keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ] + keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm'] for k in keys['lrs']: if k not in self.info: @@ -1746,7 +1747,8 @@ class TrainingState(): if args.tts_backend == "tortoise": logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ]) else: - logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ]) + log_dir = "logs" + logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/{log_dir}/') ]) if update: logs = [logs[-1]] @@ -2219,6 +2221,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non files = get_voice(voice, load_latents=False) indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' + + quantize_in_memory = args.tts_backend == "vall-e" os.makedirs(f'{indir}/audio/', exist_ok=True) @@ -2245,13 +2249,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non continue results[basename] = result - waveform, sample_rate = torchaudio.load(file) - # resample to the input rate, since it'll get resampled for training anyways - # this should also "help" increase throughput a bit when filling the dataloaders - waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE) - if waveform.shape[0] == 2: - waveform = waveform[:1] - torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) + + if not quantize_in_memory: + waveform, sample_rate = torchaudio.load(file) + # resample to the input rate, since it'll get resampled for training anyways + # this should also "help" increase throughput a bit when filling the dataloaders + waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE) + if waveform.shape[0] == 2: + waveform = waveform[:1] + + try: + kwargs = {} + if basename[-4:] == ".wav": + kwargs['encoding'] = "PCM_S" + kwargs['bits_per_sample'] = 16 + + torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs) + except Exception as e: + print(e) with open(infile, 'w', encoding="utf-8") as f: f.write(json.dumps(results, indent='\t')) @@ -2317,6 +2332,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul segments = 0 for filename in results: path = f'./voices/{voice}/{filename}' + extension = os.path.splitext(filename)[-1][1:] + out_extension = extension # "wav" + if not os.path.exists(path): path = f'./training/{voice}/{filename}' @@ -2333,7 +2351,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul duration = num_frames / sample_rate for segment in result['segments']: - file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") + file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}") sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence ) if error: @@ -2341,12 +2359,18 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul print(message) messages.append(message) continue + sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE ) if waveform.shape[0] == 2: waveform = waveform[:1] - torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16) + kwargs = {} + if file[-4:] == ".wav": + kwargs['encoding'] = "PCM_S" + kwargs['bits_per_sample'] = 16 + + torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs) segments +=1 @@ -2462,18 +2486,32 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p errored = 0 messages = [] - normalize = True + normalize = False # True phonemize = should_phonemize() lines = { 'training': [], 'validation': [] } segments = {} + quantize_in_memory = args.tts_backend == "vall-e" + if args.tts_backend != "tortoise": text_length = 0 audio_length = 0 + start_offset = -0.1 + end_offset = 0.1 + trim_silence = False + + TARGET_SAMPLE_RATE = 22050 + if args.tts_backend != "tortoise": + TARGET_SAMPLE_RATE = 24000 + if tts: + TARGET_SAMPLE_RATE = tts.input_sample_rate + for filename in tqdm(results, desc="Parsing results"): use_segment = use_segments + extension = os.path.splitext(filename)[-1][1:] + out_extension = extension # "wav" result = results[filename] lang = result['language'] language = LANGUAGES[lang] if lang in LANGUAGES else lang @@ -2481,8 +2519,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p # check if unsegmented text exceeds 200 characters if not use_segment: - if len(result['text']) > 200: - message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}" + if len(result['text']) > MAX_TRAINING_CHAR_LENGTH: + message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(result['text'])}), using segments: {filename}" print(message) messages.append(message) use_segment = True @@ -2490,13 +2528,15 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p # check if unsegmented audio exceeds 11.6s if not use_segment: path = f'{indir}/audio/{filename}' - if not os.path.exists(path): + if not quantize_in_memory and not os.path.exists(path): messages.append(f"Missing source audio: {filename}") errored += 1 continue - metadata = torchaudio.info(path) - duration = metadata.num_frames / metadata.sample_rate + duration = 0 + for segment in result['segments']: + duration = max(duration, result['segments'][segment]['end']) + if duration >= MAX_TRAINING_DURATION: message = f"Audio too large, using segments: {filename}" print(message) @@ -2511,19 +2551,36 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration: continue - path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") + path = f'{indir}/audio/' + filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}") if os.path.exists(path): continue exists = False break - if not exists: + if not quantize_in_memory and not exists: tmp = {} tmp[filename] = result print(f"Audio not segmented, segmenting: {filename}") message = slice_dataset( voice, results=tmp ) print(message) messages = messages + message.split("\n") + + waveform = None + + + if quantize_in_memory: + path = f'{indir}/audio/{filename}' + if not os.path.exists(path): + path = f'./voices/{voice}/{filename}' + + if not os.path.exists(path): + message = f"Audio not found: {path}" + print(message) + messages.append(message) + #continue + else: + waveform = torchaudio.load(path) + waveform = resample(waveform[0], waveform[1], TARGET_SAMPLE_RATE) if not use_segment: segments[filename] = { @@ -2533,13 +2590,18 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p 'normalizer': normalizer, 'phonemes': result['phonemes'] if 'phonemes' in result else None } + + if waveform: + segments[filename]['waveform'] = waveform else: for segment in result['segments']: duration = segment['end'] - segment['start'] if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration: continue - segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = { + file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}") + + segments[file] = { 'text': segment['text'], 'lang': lang, 'language': language, @@ -2547,22 +2609,27 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p 'phonemes': segment['phonemes'] if 'phonemes' in segment else None } + if waveform: + sliced, error = slice_waveform( waveform[0], waveform[1], segment['start'] + start_offset, segment['end'] + end_offset, trim_silence ) + if error: + message = f"{error}, skipping... {file}" + print(message) + messages.append(message) + segments[file]['error'] = error + #continue + else: + segments[file]['waveform'] = (sliced, waveform[1]) + jobs = { 'quantize': [[], []], 'phonemize': [[], []], } for file in tqdm(segments, desc="Parsing segments"): + extension = os.path.splitext(file)[-1][1:] result = segments[file] path = f'{indir}/audio/{file}' - if not os.path.exists(path): - message = f"Missing segment, skipping... {file}" - print(message) - messages.append(message) - errored += 1 - continue - text = result['text'] lang = result['lang'] language = result['language'] @@ -2573,28 +2640,20 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p normalized = normalizer(text) if normalize else text - if len(text) > 200: - message = f"Text length too long (200 < {len(text)}), skipping... {file}" + if len(text) > MAX_TRAINING_CHAR_LENGTH: + message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(text)}), skipping... {file}" print(message) messages.append(message) errored += 1 continue - waveform, sample_rate = torchaudio.load(path) - num_channels, num_frames = waveform.shape - duration = num_frames / sample_rate + # num_channels, num_frames = waveform.shape + #duration = num_frames / sample_rate - error = validate_waveform( waveform, sample_rate ) - if error: - message = f"{error}, skipping... {file}" - print(message) - messages.append(message) - errored += 1 - continue culled = len(text) < text_length - if not culled and audio_length > 0: - culled = duration < audio_length + #if not culled and audio_length > 0: + # culled = duration < audio_length line = f'audio/{file}|{phonemes if phonemize and phonemes else text}' @@ -2605,17 +2664,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p os.makedirs(f'{indir}/valle/', exist_ok=True) - qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}' - if not os.path.exists(qnt_file): - jobs['quantize'][0].append(qnt_file) - jobs['quantize'][1].append((waveform, sample_rate)) - """ - quantized = valle_quantize( waveform, sample_rate ).cpu() - torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}') - print("Quantized:", file) - """ - - phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}' + #phn_file = f'{indir}/valle/{file.replace(f".{extension}",".phn.txt")}' + phn_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".phn.txt")}' if not os.path.exists(phn_file): jobs['phonemize'][0].append(phn_file) jobs['phonemize'][1].append(normalized) @@ -2625,13 +2675,46 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p print("Phonemized:", file, normalized, text) """ + #qnt_file = f'{indir}/valle/{file.replace(f".{extension}",".qnt.pt")}' + qnt_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".qnt.pt")}' + if 'error' not in result: + if not quantize_in_memory and not os.path.exists(path): + message = f"Missing segment, skipping... {file}" + print(message) + messages.append(message) + errored += 1 + continue + + if not os.path.exists(qnt_file): + waveform = None + if 'waveform' in result: + waveform, sample_rate = result['waveform'] + elif os.path.exists(path): + waveform, sample_rate = torchaudio.load(path) + error = validate_waveform( waveform, sample_rate ) + if error: + message = f"{error}, skipping... {file}" + print(message) + messages.append(message) + errored += 1 + continue + + if waveform is not None: + jobs['quantize'][0].append(qnt_file) + jobs['quantize'][1].append((waveform, sample_rate)) + """ + quantized = valle_quantize( waveform, sample_rate ).cpu() + torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}') + print("Quantized:", file) + """ + for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"): qnt_file = jobs['quantize'][0][i] waveform, sample_rate = jobs['quantize'][1][i] quantized = valle_quantize( waveform, sample_rate ).cpu() torch.save(quantized, qnt_file) - print("Quantized:", qnt_file) + #print("Quantized:", qnt_file) for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"): phn_file = jobs['phonemize'][0][i] @@ -2640,7 +2723,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p try: phonemized = valle_phonemize( normalized ) open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized)) - print("Phonemized:", phn_file) + #print("Phonemized:", phn_file) except Exception as e: message = f"Failed to phonemize: {phn_file}: {normalized}" messages.append(message) @@ -2970,17 +3053,26 @@ def import_voices(files, saveAs=None, progress=None): def relative_paths( dirs ): return [ './' + os.path.relpath( d ).replace("\\", "/") for d in dirs ] -def get_voice( name, dir=get_voice_dir(), load_latents=True ): +def get_voice( name, dir=get_voice_dir(), load_latents=True, extensions=["wav", "mp3", "flac"] ): subj = f'{dir}/{name}/' if not os.path.isdir(subj): return - - voice = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.flac')) + files = os.listdir(subj) + if load_latents: - voice = voice + list(glob(f'{subj}/*.pth')) + extensions.append("pth") + + voice = [] + for file in files: + ext = os.path.splitext(file)[-1][1:] + if ext not in extensions: + continue + + voice.append(f'{subj}/{file}') + return sorted( voice ) -def get_voice_list(dir=get_voice_dir(), append_defaults=False): +def get_voice_list(dir=get_voice_dir(), append_defaults=False, extensions=["wav", "mp3", "flac", "pth"]): defaults = [ "random", "microphone" ] os.makedirs(dir, exist_ok=True) #res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) @@ -2993,7 +3085,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False): continue if len(os.listdir(os.path.join(dir, name))) == 0: continue - files = get_voice( name, dir=dir ) + files = get_voice( name, dir=dir, extensions=extensions ) if len(files) > 0: res.append(name) @@ -3001,7 +3093,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False): for subdir in os.listdir(f'{dir}/{name}'): if not os.path.isdir(f'{dir}/{name}/{subdir}'): continue - files = get_voice( f'{name}/{subdir}', dir=dir ) + files = get_voice( f'{name}/{subdir}', dir=dir, extensions=extensions ) if len(files) == 0: continue res.append(f'{name}/{subdir}')