diff --git a/src/utils.py b/src/utils.py index 94e7769..00d9c24 100755 --- a/src/utils.py +++ b/src/utils.py @@ -679,14 +679,23 @@ def generate_valle(**kwargs): def fetch_voice( voice ): if voice in voice_cache: return voice_cache[voice] + + """ voice_dir = f'./training/{voice}/audio/' if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0: voice_dir = f'./voices/{voice}/' files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ] + """ + + if os.path.isdir(f'./training/{voice}/audio/'): + files = get_voice(name="audio", dir=f"./training/{voice}/", load_latents=False) + else: + files = get_voice(name=voice, load_latents=False) + # return files - voice_cache[voice] = random.choice(files) + voice_cache[voice] = random.sample(files, k=min(3, len(files))) return voice_cache[voice] def get_settings( override=None ): @@ -707,7 +716,7 @@ def generate_valle(**kwargs): continue settings[k] = override[k] - settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ] + settings['references'] = fetch_voice(voice=selected_voice) # [ fetch_voice(voice=selected_voice) for _ in range(3) ] return settings if not parameters['delimiter']: