forked from mrq/ai-voice-cloning
added update checking for dlas and tortoise-tts, caching voices (for a given model and voice name) so random latents will remain the same
This commit is contained in:
parent
166d491a98
commit
c27ee3ce95
22
src/utils.py
22
src/utils.py
|
@ -103,8 +103,12 @@ def generate(
|
||||||
if seed == 0:
|
if seed == 0:
|
||||||
seed = None
|
seed = None
|
||||||
|
|
||||||
|
voice_cache = {}
|
||||||
def fetch_voice( voice ):
|
def fetch_voice( voice ):
|
||||||
print(f"Loading voice: {voice}")
|
print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
|
||||||
|
cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
|
||||||
|
if cache_key in voice_cache:
|
||||||
|
return voice_cache[cache_key]
|
||||||
|
|
||||||
sample_voice = None
|
sample_voice = None
|
||||||
if voice == "microphone":
|
if voice == "microphone":
|
||||||
|
@ -126,7 +130,8 @@ def generate(
|
||||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||||
voice_samples = None
|
voice_samples = None
|
||||||
|
|
||||||
return (voice_samples, conditioning_latents, sample_voice)
|
voice_cache[cache_key] = (voice_samples, conditioning_latents, sample_voice)
|
||||||
|
return voice_cache[cache_key]
|
||||||
|
|
||||||
def get_settings( override=None ):
|
def get_settings( override=None ):
|
||||||
settings = {
|
settings = {
|
||||||
|
@ -1479,12 +1484,19 @@ def curl(url):
|
||||||
print(e)
|
print(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def check_for_updates():
|
def check_for_updates( dir = None ):
|
||||||
if not os.path.isfile('./.git/FETCH_HEAD'):
|
if dir is None:
|
||||||
|
check_for_updates("./")
|
||||||
|
check_for_updates("./dlas/")
|
||||||
|
check_for_updates("./tortoise-tts/")
|
||||||
|
return
|
||||||
|
|
||||||
|
git_dir = f'{dir}/.git/'
|
||||||
|
if not os.path.isfile(f'{git_dir}/FETCH_HEAD'):
|
||||||
print("Cannot check for updates: not from a git repo")
|
print("Cannot check for updates: not from a git repo")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
|
with open(f'{git_dir}/FETCH_HEAD', 'r', encoding="utf-8") as f:
|
||||||
head = f.read()
|
head = f.read()
|
||||||
|
|
||||||
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
|
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 26133c20314b77155e77be804b43909dab9809d6
|
Subproject commit cc36c0997c8711889ef8028002fc9e41abd5c5f0
|
Loading…
Reference in New Issue
Block a user