From e731b9ba84bc407cab473817bca6f3346cba1718 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Mon, 6 Mar 2023 23:07:16 +0000
Subject: [PATCH] reworked generation of the metadata to embed; it should now
 store overridden settings

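The per-output metadata is now built by a get_info() helper that mirrors
the old info dict, merges in any per-line {...} JSON prompt overrides,
and base64-embeds the voice's conditioning latents. Each audio_cache
entry keeps its own 'settings' copy, so the metadata written for each
output reflects the settings actually used for that line.

A minimal sketch of reading the embedded metadata back out of an output
WAV (the file path below is a placeholder):

    import json, base64
    import music_tag

    # generation settings are stored as JSON in the 'lyrics' tag
    metadata = music_tag.load_file('./results/random/random_00000.wav')
    info = json.loads(metadata['lyrics'].value)

    # conditioning latents, if embedded, are base64-encoded
    if info.get('latents'):
        with open('cond_latents.pth', 'wb') as f:
            f.write(base64.b64decode(info['latents']))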
---
 src/utils.py | 140 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 85 insertions(+), 55 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 0074bc6..6fe3ac5 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -282,6 +282,72 @@ def generate(
 			name = f"{name}_{candidate}"
 		return name
 
+	def get_info( voice, settings = None, latents = True ):
+		info = {
+			'text': text,
+			'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter,
+			'emotion': emotion,
+			'prompt': prompt,
+			'voice': voice,
+			'seed': seed,
+			'candidates': candidates,
+			'num_autoregressive_samples': num_autoregressive_samples,
+			'diffusion_iterations': diffusion_iterations,
+			'temperature': temperature,
+			'diffusion_sampler': diffusion_sampler,
+			'breathing_room': breathing_room,
+			'cvvp_weight': cvvp_weight,
+			'top_p': top_p,
+			'diffusion_temperature': diffusion_temperature,
+			'length_penalty': length_penalty,
+			'repetition_penalty': repetition_penalty,
+			'cond_free_k': cond_free_k,
+			'experimentals': experimental_checkboxes,
+			'time': time.time()-full_start_time,
+
+			'datetime': datetime.now().isoformat(),
+			'model': tts.autoregressive_model_path,
+			'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
+		}
+
+		if settings is not None:
+			for k in settings:
+				if k in info:
+					info[k] = settings[k]
+
+			if 'half_p' in settings and 'cond_free' in settings:
+				info['experimentals'] = []
+				if settings['half_p']:
+					info['experimentals'].append("Half Precision")
+				if settings['cond_free']:
+					info['experimentals'].append("Conditioning-Free")
+
+		if latents and "latents" not in info:
+			voice = info['voice']
+			latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
+
+			if voice == "random" or voice == "microphone":
+				# random/microphone voices have no saved latents on disk; save the ones used for this generation so they can be embedded
+				if settings and settings.get('conditioning_latents') is not None:
+					voice_dir = f'{get_voice_dir()}/{voice}'
+					os.makedirs(voice_dir, exist_ok=True)
+					latents_path = f'{voice_dir}/cond_latents.pth'
+					torch.save(conditioning_latents, latents_path)
+			else:
+				if settings and "model_hash" in settings:
+					latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
+				elif hasattr(tts, "autoregressive_model_hash"):
+					latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
+
+			if latents_path and os.path.exists(latents_path):
+				try:
+					with open(latents_path, 'rb') as f:
+						info['latents'] = base64.b64encode(f.read()).decode("ascii")
+				except Exception:
+					pass # latents are optional metadata; ignore files that can't be read
+
+		return info
+
 	for line, cut_text in enumerate(texts):
 		if emotion == "Custom":
 			if prompt and prompt.strip() != "":
@@ -295,6 +361,7 @@ def generate(
 
 		# do setting editing
 		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text) 
+		override = None
 		if match and len(match) > 0:
 			match = match[0]
 			try:
@@ -304,11 +371,11 @@ def generate(
 				raise Exception("Prompt settings editing requested, but received invalid JSON")
 
 			cut_text = match[1].strip()
-			new_settings = get_settings( override )
-
-			gen, additionals = tts.tts(cut_text, **new_settings )
+			used_settings = get_settings( override )
 		else:
-			gen, additionals = tts.tts(cut_text, **settings )
+			used_settings = settings.copy()
+
+		gen, additionals = tts.tts(cut_text, **used_settings )
 
 		seed = additionals[0]
 		run_time = time.time()-start_time
@@ -320,10 +387,16 @@ def generate(
 		for j, g in enumerate(gen):
 			audio = g.squeeze(0).cpu()
 			name = get_name(line=line, candidate=j)
+
+			used_settings['text'] = cut_text
+			used_settings['time'] = run_time
+			used_settings['datetime'] = datetime.now().isoformat()
+			used_settings['model'] = tts.autoregressive_model_path
+			used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None
+
 			audio_cache[name] = {
 				'audio': audio,
-				'text': cut_text,
-				'time': run_time
+				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=used_settings)
 			}
 			# save here in case some error happens mid-batch
 			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
@@ -341,7 +414,7 @@ def generate(
 
 		audio_cache[k]['audio'] = audio
 		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
- 
+
 	output_voices = []
 	for candidate in range(candidates):
 		if len(texts) > 1:
@@ -358,40 +431,13 @@ def generate(
 			audio = audio.squeeze(0).cpu()
 			audio_cache[name] = {
 				'audio': audio,
-				'text': text,
-				'time': time.time()-full_start_time,
+				'settings': get_info(voice=voice),
 				'output': True
 			}
 		else:
 			name = get_name(candidate=candidate)
 			audio_cache[name]['output'] = True
 
-	info = {
-		'text': text,
-		'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter,
-		'emotion': emotion,
-		'prompt': prompt,
-		'voice': voice,
-		'seed': seed,
-		'candidates': candidates,
-		'num_autoregressive_samples': num_autoregressive_samples,
-		'diffusion_iterations': diffusion_iterations,
-		'temperature': temperature,
-		'diffusion_sampler': diffusion_sampler,
-		'breathing_room': breathing_room,
-		'cvvp_weight': cvvp_weight,
-		'top_p': top_p,
-		'diffusion_temperature': diffusion_temperature,
-		'length_penalty': length_penalty,
-		'repetition_penalty': repetition_penalty,
-		'cond_free_k': cond_free_k,
-		'experimentals': experimental_checkboxes,
-		'time': time.time()-full_start_time,
-
-		'datetime': datetime.now().isoformat(),
-		'model': tts.autoregressive_model_path,
-		'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
-	}
 
 	if args.voice_fixer:
 		if not voicefixer:
@@ -414,8 +460,7 @@ def generate(
 			)
 			
 			fixed_cache[f'{name}_fixed'] = {
-				'text': audio_cache[name]['text'],
-				'time': audio_cache[name]['time'],
+				'settings': audio_cache[name]['settings'],
 				'output': True
 			}
 			audio_cache[name]['output'] = False
@@ -434,36 +479,21 @@ def generate(
 
 		if not args.embed_output_metadata:
 			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
-				f.write(json.dumps(info, indent='\t') )
-
-
-	if voice and voice != "random" and conditioning_latents is not None:
-		latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
-
-		if hasattr(tts, 'autoregressive_model_hash'):
-			latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
-
-		try:
-			with open(latents_path, 'rb') as f:
-				info['latents'] = base64.b64encode(f.read()).decode("ascii")
-		except Exception as e:
-			pass
+				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
 
 	if args.embed_output_metadata:
 		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
 			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
 				continue
 
-			info['text'] = audio_cache[name]['text']
-			info['time'] = audio_cache[name]['time']
-
 			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
-			metadata['lyrics'] = json.dumps(info) 
+			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
 			metadata.save()
  
 	if sample_voice is not None:
 		sample_voice = (tts.input_sample_rate, sample_voice.numpy())
 
+	info = get_info(voice=voice, latents=False)
 	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
 
 	info['seed'] = settings['use_deterministic_seed']