From e731b9ba84bc407cab473817bca6f3346cba1718 Mon Sep 17 00:00:00 2001 From: mrq <barry.quiggles@protonmail.com> Date: Mon, 6 Mar 2023 23:07:16 +0000 Subject: [PATCH] reworked generating metadata to embed, should now store overrided settings --- src/utils.py | 140 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 85 insertions(+), 55 deletions(-) diff --git a/src/utils.py b/src/utils.py index 0074bc6..6fe3ac5 100755 --- a/src/utils.py +++ b/src/utils.py @@ -282,6 +282,72 @@ def generate( name = f"{name}_{candidate}" return name + def get_info( voice, settings = None, latents = True ): + info = { + 'text': text, + 'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter, + 'emotion': emotion, + 'prompt': prompt, + 'voice': voice, + 'seed': seed, + 'candidates': candidates, + 'num_autoregressive_samples': num_autoregressive_samples, + 'diffusion_iterations': diffusion_iterations, + 'temperature': temperature, + 'diffusion_sampler': diffusion_sampler, + 'breathing_room': breathing_room, + 'cvvp_weight': cvvp_weight, + 'top_p': top_p, + 'diffusion_temperature': diffusion_temperature, + 'length_penalty': length_penalty, + 'repetition_penalty': repetition_penalty, + 'cond_free_k': cond_free_k, + 'experimentals': experimental_checkboxes, + 'time': time.time()-full_start_time, + + 'datetime': datetime.now().isoformat(), + 'model': tts.autoregressive_model_path, + 'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None, + } + + if settings is not None: + for k in settings: + if k in info: + info[k] = settings[k] + + if 'half_p' in settings and 'cond_free' in settings: + info['experimentals'] = [] + if settings['half_p']: + info['experimentals'].append("Half Precision") + if settings['cond_free']: + info['experimentals'].append("Conditioning-Free") + + if latents and "latents" not in info: + voice = info['voice'] + latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth' + + if voice == "random" or voice == "microphone": + if latents and settings['conditioning_latents']: + dir = f'{get_voice_dir()}/{voice}/' + if not os.path.isdir(dir): + os.makedirs(dir, exist_ok=True) + latents_path = f'{dir}/cond_latents.pth' + torch.save(conditioning_latents, latents_path) + else: + if settings and "model_hash" in settings: + latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth' + elif hasattr(tts, "autoregressive_model_hash"): + latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth' + + if latents_path and os.path.exists(latents_path): + try: + with open(latents_path, 'rb') as f: + info['latents'] = base64.b64encode(f.read()).decode("ascii") + except Exception as e: + pass + + return info + for line, cut_text in enumerate(texts): if emotion == "Custom": if prompt and prompt.strip() != "": @@ -295,6 +361,7 @@ def generate( # do setting editing match = re.findall(r'^(\{.+\}) (.+?)$', cut_text) + override = None if match and len(match) > 0: match = match[0] try: @@ -304,11 +371,11 @@ def generate( raise Exception("Prompt settings editing requested, but received invalid JSON") cut_text = match[1].strip() - new_settings = get_settings( override ) - - gen, additionals = tts.tts(cut_text, **new_settings ) + used_settings = get_settings( override ) else: - gen, additionals = tts.tts(cut_text, **settings ) + used_settings = settings.copy() + + gen, additionals = tts.tts(cut_text, **used_settings ) seed = additionals[0] run_time = time.time()-start_time @@ -320,10 +387,16 @@ def generate( for j, g in enumerate(gen): audio = g.squeeze(0).cpu() name = get_name(line=line, candidate=j) + + used_settings['text'] = cut_text + used_settings['time'] = run_time + used_settings['datetime'] = datetime.now().isoformat(), + used_settings['model'] = tts.autoregressive_model_path + used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None + audio_cache[name] = { 'audio': audio, - 'text': cut_text, - 'time': run_time + 'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=used_settings) } # save here in case some error happens mid-batch torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate) @@ -341,7 +414,7 @@ def generate( audio_cache[k]['audio'] = audio torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate) - + output_voices = [] for candidate in range(candidates): if len(texts) > 1: @@ -358,40 +431,13 @@ def generate( audio = audio.squeeze(0).cpu() audio_cache[name] = { 'audio': audio, - 'text': text, - 'time': time.time()-full_start_time, + 'settings': get_info(voice=voice), 'output': True } else: name = get_name(candidate=candidate) audio_cache[name]['output'] = True - info = { - 'text': text, - 'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter, - 'emotion': emotion, - 'prompt': prompt, - 'voice': voice, - 'seed': seed, - 'candidates': candidates, - 'num_autoregressive_samples': num_autoregressive_samples, - 'diffusion_iterations': diffusion_iterations, - 'temperature': temperature, - 'diffusion_sampler': diffusion_sampler, - 'breathing_room': breathing_room, - 'cvvp_weight': cvvp_weight, - 'top_p': top_p, - 'diffusion_temperature': diffusion_temperature, - 'length_penalty': length_penalty, - 'repetition_penalty': repetition_penalty, - 'cond_free_k': cond_free_k, - 'experimentals': experimental_checkboxes, - 'time': time.time()-full_start_time, - - 'datetime': datetime.now().isoformat(), - 'model': tts.autoregressive_model_path, - 'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None, - } if args.voice_fixer: if not voicefixer: @@ -414,8 +460,7 @@ def generate( ) fixed_cache[f'{name}_fixed'] = { - 'text': audio_cache[name]['text'], - 'time': audio_cache[name]['time'], + 'settings': audio_cache[name]['settings'], 'output': True } audio_cache[name]['output'] = False @@ -434,36 +479,21 @@ def generate( if not args.embed_output_metadata: with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: - f.write(json.dumps(info, indent='\t') ) - - - if voice and voice != "random" and conditioning_latents is not None: - latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth' - - if hasattr(tts, 'autoregressive_model_hash'): - latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth' - - try: - with open(latents_path, 'rb') as f: - info['latents'] = base64.b64encode(f.read()).decode("ascii") - except Exception as e: - pass + f.write(json.dumps(audio_cache[name]['settings'], indent='\t') ) if args.embed_output_metadata: for name in progress.tqdm(audio_cache, desc="Embedding metadata..."): if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']: continue - info['text'] = audio_cache[name]['text'] - info['time'] = audio_cache[name]['time'] - metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav") - metadata['lyrics'] = json.dumps(info) + metadata['lyrics'] = json.dumps(audio_cache[name]['settings']) metadata.save() if sample_voice is not None: sample_voice = (tts.input_sample_rate, sample_voice.numpy()) + info = get_info(voice=voice, latents=False) print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n") info['seed'] = settings['use_deterministic_seed']