forked from mrq/ai-voice-cloning
added to embedded metadata: datetime, model path, model hash
This commit is contained in:
parent
81eb58f0d6
commit
787b44807a
64
src/utils.py
64
src/utils.py
|
@ -307,6 +307,10 @@ def generate(
|
||||||
'cond_free_k': cond_free_k,
|
'cond_free_k': cond_free_k,
|
||||||
'experimentals': experimental_checkboxes,
|
'experimentals': experimental_checkboxes,
|
||||||
'time': time.time()-full_start_time,
|
'time': time.time()-full_start_time,
|
||||||
|
|
||||||
|
'datetime': datetime.now().isoformat(),
|
||||||
|
'model': tts.autoregressive_model_path,
|
||||||
|
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -324,6 +328,7 @@ def generate(
|
||||||
|
|
||||||
if args.voice_fixer:
|
if args.voice_fixer:
|
||||||
if not voicefixer:
|
if not voicefixer:
|
||||||
|
progress(0, "Loading voicefix...")
|
||||||
load_voicefixer()
|
load_voicefixer()
|
||||||
|
|
||||||
fixed_cache = {}
|
fixed_cache = {}
|
||||||
|
@ -1006,7 +1011,33 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
||||||
res = res + ["random", "microphone"]
|
res = res + ["random", "microphone"]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def get_autoregressive_models(dir="./models/finetunes/"):
|
def hash_file(path, algo="md5", buffer_size=0):
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
hash = None
|
||||||
|
if algo == "md5":
|
||||||
|
hash = hashlib.md5()
|
||||||
|
elif algo == "sha1":
|
||||||
|
hash = hashlib.sha1()
|
||||||
|
else:
|
||||||
|
raise Exception(f'Unknown hash algorithm specified: {algo}')
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise Exception(f'Path not found: {path}')
|
||||||
|
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
if buffer_size > 0:
|
||||||
|
while True:
|
||||||
|
data = f.read(buffer_size)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
hash.update(data)
|
||||||
|
else:
|
||||||
|
hash.update(f.read())
|
||||||
|
|
||||||
|
return "{0}".format(hash.hexdigest())
|
||||||
|
|
||||||
|
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
base = [get_model_path('autoregressive.pth')]
|
base = [get_model_path('autoregressive.pth')]
|
||||||
halfp = get_halfp_model_path()
|
halfp = get_halfp_model_path()
|
||||||
|
@ -1018,12 +1049,20 @@ def get_autoregressive_models(dir="./models/finetunes/"):
|
||||||
for training in os.listdir(f'./training/'):
|
for training in os.listdir(f'./training/'):
|
||||||
if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/models/'):
|
if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/models/'):
|
||||||
continue
|
continue
|
||||||
#found = found + sorted([ f'./training/{training}/models/{d}' for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
|
|
||||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
|
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
|
||||||
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
|
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
|
||||||
#found.append(f'./training/{training}/models/{models[-1]}_gpt.pth')
|
|
||||||
|
|
||||||
return base + additionals + found
|
res = base + additionals + found
|
||||||
|
|
||||||
|
if prefixed:
|
||||||
|
for i in range(len(res)):
|
||||||
|
path = res[i]
|
||||||
|
hash = hash_file(path)
|
||||||
|
shorthash = hash[:8]
|
||||||
|
|
||||||
|
res[i] = f'[{shorthash}] {path}'
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def get_dataset_list(dir="./training/"):
|
def get_dataset_list(dir="./training/"):
|
||||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
@ -1256,8 +1295,6 @@ def save_args_settings():
|
||||||
'training-default-bnb': args.training_default_bnb,
|
'training-default-bnb': args.training_default_bnb,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(settings)
|
|
||||||
|
|
||||||
os.makedirs('./config/', exist_ok=True)
|
os.makedirs('./config/', exist_ok=True)
|
||||||
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(settings, indent='\t') )
|
f.write(json.dumps(settings, indent='\t') )
|
||||||
|
@ -1322,9 +1359,7 @@ def read_generate_settings(file, read_latents=True):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if j is None:
|
if j is not None:
|
||||||
print("No metadata found in audio file to read")
|
|
||||||
else:
|
|
||||||
if 'latents' in j:
|
if 'latents' in j:
|
||||||
if read_latents:
|
if read_latents:
|
||||||
latents = base64.b64decode(j['latents'])
|
latents = base64.b64decode(j['latents'])
|
||||||
|
@ -1360,6 +1395,10 @@ def load_tts( restart=False, model=None ):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
load_autoregressive_model(args.autoregressive_model)
|
load_autoregressive_model(args.autoregressive_model)
|
||||||
|
|
||||||
|
if not hasattr(tts, 'autoregressive_model_hash'):
|
||||||
|
tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
|
||||||
|
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
|
||||||
get_model_path('dvae.pth')
|
get_model_path('dvae.pth')
|
||||||
|
@ -1381,6 +1420,10 @@ def reload_tts( model=None ):
|
||||||
load_tts( restart=True, model=model )
|
load_tts( restart=True, model=model )
|
||||||
|
|
||||||
def update_autoregressive_model(autoregressive_model_path):
|
def update_autoregressive_model(autoregressive_model_path):
|
||||||
|
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
||||||
|
if match:
|
||||||
|
autoregressive_model_path = match[0]
|
||||||
|
|
||||||
if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
|
if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
|
||||||
print(f"Invalid model: {autoregressive_model_path}")
|
print(f"Invalid model: {autoregressive_model_path}")
|
||||||
return
|
return
|
||||||
|
@ -1416,6 +1459,9 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
if tts.preloaded_tensors:
|
if tts.preloaded_tensors:
|
||||||
tts.autoregressive = tts.autoregressive.to(tts.device)
|
tts.autoregressive = tts.autoregressive.to(tts.device)
|
||||||
|
|
||||||
|
if not hasattr(tts, 'autoregressive_model_hash'):
|
||||||
|
tts.autoregressive_model_hash = hash_file(autoregressive_model_path)
|
||||||
|
|
||||||
print(f"Loaded model: {tts.autoregressive_model_path}")
|
print(f"Loaded model: {tts.autoregressive_model_path}")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
|
@ -129,6 +129,9 @@ history_headers = {
|
||||||
"Rep Pen": "repetition_penalty",
|
"Rep Pen": "repetition_penalty",
|
||||||
"Cond-Free K": "cond_free_k",
|
"Cond-Free K": "cond_free_k",
|
||||||
"Time": "time",
|
"Time": "time",
|
||||||
|
"Datetime": "datetime",
|
||||||
|
"Model": "model",
|
||||||
|
"Model Hash": "model_hash",
|
||||||
}
|
}
|
||||||
|
|
||||||
def history_view_results( voice ):
|
def history_view_results( voice ):
|
||||||
|
@ -147,7 +150,7 @@ def history_view_results( voice ):
|
||||||
for k in history_headers:
|
for k in history_headers:
|
||||||
v = file
|
v = file
|
||||||
if k != "Name":
|
if k != "Name":
|
||||||
v = metadata[history_headers[k]]
|
v = metadata[history_headers[k]] if history_headers[k] in metadata else '?'
|
||||||
values.append(v)
|
values.append(v)
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,8 +177,6 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
|
|
||||||
latents = f'{outdir}/cond_latents.pth'
|
latents = f'{outdir}/cond_latents.pth'
|
||||||
|
|
||||||
print(j, latents)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
gr.update(value=j, visible=j is not None),
|
gr.update(value=j, visible=j is not None),
|
||||||
gr.update(visible=j is not None),
|
gr.update(visible=j is not None),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user