forked from mrq/ai-voice-cloning
reworked generating metadata to embed, should now store overrided settings
This commit is contained in:
parent
7798767fc6
commit
e731b9ba84
140
src/utils.py
140
src/utils.py
|
@ -282,6 +282,72 @@ def generate(
|
||||||
name = f"{name}_{candidate}"
|
name = f"{name}_{candidate}"
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
def get_info( voice, settings = None, latents = True ):
|
||||||
|
info = {
|
||||||
|
'text': text,
|
||||||
|
'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter,
|
||||||
|
'emotion': emotion,
|
||||||
|
'prompt': prompt,
|
||||||
|
'voice': voice,
|
||||||
|
'seed': seed,
|
||||||
|
'candidates': candidates,
|
||||||
|
'num_autoregressive_samples': num_autoregressive_samples,
|
||||||
|
'diffusion_iterations': diffusion_iterations,
|
||||||
|
'temperature': temperature,
|
||||||
|
'diffusion_sampler': diffusion_sampler,
|
||||||
|
'breathing_room': breathing_room,
|
||||||
|
'cvvp_weight': cvvp_weight,
|
||||||
|
'top_p': top_p,
|
||||||
|
'diffusion_temperature': diffusion_temperature,
|
||||||
|
'length_penalty': length_penalty,
|
||||||
|
'repetition_penalty': repetition_penalty,
|
||||||
|
'cond_free_k': cond_free_k,
|
||||||
|
'experimentals': experimental_checkboxes,
|
||||||
|
'time': time.time()-full_start_time,
|
||||||
|
|
||||||
|
'datetime': datetime.now().isoformat(),
|
||||||
|
'model': tts.autoregressive_model_path,
|
||||||
|
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings is not None:
|
||||||
|
for k in settings:
|
||||||
|
if k in info:
|
||||||
|
info[k] = settings[k]
|
||||||
|
|
||||||
|
if 'half_p' in settings and 'cond_free' in settings:
|
||||||
|
info['experimentals'] = []
|
||||||
|
if settings['half_p']:
|
||||||
|
info['experimentals'].append("Half Precision")
|
||||||
|
if settings['cond_free']:
|
||||||
|
info['experimentals'].append("Conditioning-Free")
|
||||||
|
|
||||||
|
if latents and "latents" not in info:
|
||||||
|
voice = info['voice']
|
||||||
|
latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
|
||||||
|
|
||||||
|
if voice == "random" or voice == "microphone":
|
||||||
|
if latents and settings['conditioning_latents']:
|
||||||
|
dir = f'{get_voice_dir()}/{voice}/'
|
||||||
|
if not os.path.isdir(dir):
|
||||||
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
latents_path = f'{dir}/cond_latents.pth'
|
||||||
|
torch.save(conditioning_latents, latents_path)
|
||||||
|
else:
|
||||||
|
if settings and "model_hash" in settings:
|
||||||
|
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
|
||||||
|
elif hasattr(tts, "autoregressive_model_hash"):
|
||||||
|
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
|
||||||
|
|
||||||
|
if latents_path and os.path.exists(latents_path):
|
||||||
|
try:
|
||||||
|
with open(latents_path, 'rb') as f:
|
||||||
|
info['latents'] = base64.b64encode(f.read()).decode("ascii")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
for line, cut_text in enumerate(texts):
|
for line, cut_text in enumerate(texts):
|
||||||
if emotion == "Custom":
|
if emotion == "Custom":
|
||||||
if prompt and prompt.strip() != "":
|
if prompt and prompt.strip() != "":
|
||||||
|
@ -295,6 +361,7 @@ def generate(
|
||||||
|
|
||||||
# do setting editing
|
# do setting editing
|
||||||
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
||||||
|
override = None
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
match = match[0]
|
match = match[0]
|
||||||
try:
|
try:
|
||||||
|
@ -304,11 +371,11 @@ def generate(
|
||||||
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||||
|
|
||||||
cut_text = match[1].strip()
|
cut_text = match[1].strip()
|
||||||
new_settings = get_settings( override )
|
used_settings = get_settings( override )
|
||||||
|
|
||||||
gen, additionals = tts.tts(cut_text, **new_settings )
|
|
||||||
else:
|
else:
|
||||||
gen, additionals = tts.tts(cut_text, **settings )
|
used_settings = settings.copy()
|
||||||
|
|
||||||
|
gen, additionals = tts.tts(cut_text, **used_settings )
|
||||||
|
|
||||||
seed = additionals[0]
|
seed = additionals[0]
|
||||||
run_time = time.time()-start_time
|
run_time = time.time()-start_time
|
||||||
|
@ -320,10 +387,16 @@ def generate(
|
||||||
for j, g in enumerate(gen):
|
for j, g in enumerate(gen):
|
||||||
audio = g.squeeze(0).cpu()
|
audio = g.squeeze(0).cpu()
|
||||||
name = get_name(line=line, candidate=j)
|
name = get_name(line=line, candidate=j)
|
||||||
|
|
||||||
|
used_settings['text'] = cut_text
|
||||||
|
used_settings['time'] = run_time
|
||||||
|
used_settings['datetime'] = datetime.now().isoformat(),
|
||||||
|
used_settings['model'] = tts.autoregressive_model_path
|
||||||
|
used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None
|
||||||
|
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
'audio': audio,
|
'audio': audio,
|
||||||
'text': cut_text,
|
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=used_settings)
|
||||||
'time': run_time
|
|
||||||
}
|
}
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
|
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
|
||||||
|
@ -341,7 +414,7 @@ def generate(
|
||||||
|
|
||||||
audio_cache[k]['audio'] = audio
|
audio_cache[k]['audio'] = audio
|
||||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
output_voices = []
|
output_voices = []
|
||||||
for candidate in range(candidates):
|
for candidate in range(candidates):
|
||||||
if len(texts) > 1:
|
if len(texts) > 1:
|
||||||
|
@ -358,40 +431,13 @@ def generate(
|
||||||
audio = audio.squeeze(0).cpu()
|
audio = audio.squeeze(0).cpu()
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
'audio': audio,
|
'audio': audio,
|
||||||
'text': text,
|
'settings': get_info(voice=voice),
|
||||||
'time': time.time()-full_start_time,
|
|
||||||
'output': True
|
'output': True
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
name = get_name(candidate=candidate)
|
name = get_name(candidate=candidate)
|
||||||
audio_cache[name]['output'] = True
|
audio_cache[name]['output'] = True
|
||||||
|
|
||||||
info = {
|
|
||||||
'text': text,
|
|
||||||
'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter,
|
|
||||||
'emotion': emotion,
|
|
||||||
'prompt': prompt,
|
|
||||||
'voice': voice,
|
|
||||||
'seed': seed,
|
|
||||||
'candidates': candidates,
|
|
||||||
'num_autoregressive_samples': num_autoregressive_samples,
|
|
||||||
'diffusion_iterations': diffusion_iterations,
|
|
||||||
'temperature': temperature,
|
|
||||||
'diffusion_sampler': diffusion_sampler,
|
|
||||||
'breathing_room': breathing_room,
|
|
||||||
'cvvp_weight': cvvp_weight,
|
|
||||||
'top_p': top_p,
|
|
||||||
'diffusion_temperature': diffusion_temperature,
|
|
||||||
'length_penalty': length_penalty,
|
|
||||||
'repetition_penalty': repetition_penalty,
|
|
||||||
'cond_free_k': cond_free_k,
|
|
||||||
'experimentals': experimental_checkboxes,
|
|
||||||
'time': time.time()-full_start_time,
|
|
||||||
|
|
||||||
'datetime': datetime.now().isoformat(),
|
|
||||||
'model': tts.autoregressive_model_path,
|
|
||||||
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.voice_fixer:
|
if args.voice_fixer:
|
||||||
if not voicefixer:
|
if not voicefixer:
|
||||||
|
@ -414,8 +460,7 @@ def generate(
|
||||||
)
|
)
|
||||||
|
|
||||||
fixed_cache[f'{name}_fixed'] = {
|
fixed_cache[f'{name}_fixed'] = {
|
||||||
'text': audio_cache[name]['text'],
|
'settings': audio_cache[name]['settings'],
|
||||||
'time': audio_cache[name]['time'],
|
|
||||||
'output': True
|
'output': True
|
||||||
}
|
}
|
||||||
audio_cache[name]['output'] = False
|
audio_cache[name]['output'] = False
|
||||||
|
@ -434,36 +479,21 @@ def generate(
|
||||||
|
|
||||||
if not args.embed_output_metadata:
|
if not args.embed_output_metadata:
|
||||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(info, indent='\t') )
|
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||||
|
|
||||||
|
|
||||||
if voice and voice != "random" and conditioning_latents is not None:
|
|
||||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
|
|
||||||
|
|
||||||
if hasattr(tts, 'autoregressive_model_hash'):
|
|
||||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(latents_path, 'rb') as f:
|
|
||||||
info['latents'] = base64.b64encode(f.read()).decode("ascii")
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if args.embed_output_metadata:
|
if args.embed_output_metadata:
|
||||||
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
info['text'] = audio_cache[name]['text']
|
|
||||||
info['time'] = audio_cache[name]['time']
|
|
||||||
|
|
||||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
||||||
metadata['lyrics'] = json.dumps(info)
|
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||||
metadata.save()
|
metadata.save()
|
||||||
|
|
||||||
if sample_voice is not None:
|
if sample_voice is not None:
|
||||||
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
||||||
|
|
||||||
|
info = get_info(voice=voice, latents=False)
|
||||||
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||||
|
|
||||||
info['seed'] = settings['use_deterministic_seed']
|
info['seed'] = settings['use_deterministic_seed']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user