From 7798767fc6a00601b1b2ca487738e9c54a1f4db2 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Mon, 6 Mar 2023 21:48:34 +0000
Subject: [PATCH] added settings editing (will add a guide on what to do later,
 and an example)

---
 src/utils.py | 244 ++++++++++++++++++++++++++++++++-------------------
 1 file changed, 153 insertions(+), 91 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index c4eea7c..0074bc6 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -90,46 +90,59 @@ def generate(
 
 	do_gc()
 
-	if voice != "microphone":
-		voices = [voice]
-	else:
-		voices = []
+	voices = {}
 
-	if voice == "microphone":
-		if mic_audio is None:
-			raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
-		mic = load_audio(mic_audio, tts.input_sample_rate)
-		voice_samples, conditioning_latents = [mic], None
-	elif voice == "random":
-		voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
-	else:
-		progress(0, desc="Loading voice...")
-		# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
-		if hasattr(tts, 'autoregressive_model_hash'):
-			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
+	voice_samples = None
+	conditioning_latents =None
+	sample_voice = None
+
+	def fetch_voice( requested ):
+		voice = requested
+
+		if voice in voices:
+			return voices[voice]
+
+		print(f"Loading voice: {voice}")
+
+		if voice == "microphone":
+			if mic_audio is None:
+				raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
+			voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None
+		elif voice == "random":
+			voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
 		else:
-			voice_samples, conditioning_latents = load_voice(voice)
-
-	if voice_samples and len(voice_samples) > 0:
-		sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
-
-		conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
-		if len(conditioning_latents) == 4:
-			conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
-			
-		if voice != "microphone":
+			progress(0, desc=f"Loading voice: {voice}")
+			# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
 			if hasattr(tts, 'autoregressive_model_hash'):
-				torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+				voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
 			else:
-				torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
-		voice_samples = None
-	else:
-		if conditioning_latents is not None:
-			sample_voice, _ = load_voice(voice, load_latents=False)
-			if sample_voice and len(sample_voice) > 0:
-				sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
+				voice_samples, conditioning_latents = load_voice(voice)
+
+		if voice_samples and len(voice_samples) > 0:
+			sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
+
+			conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
+			if len(conditioning_latents) == 4:
+				conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
+				
+			if voice != "microphone":
+				if hasattr(tts, 'autoregressive_model_hash'):
+					torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+				else:
+					torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
+			voice_samples = None
 		else:
-			sample_voice = None
+			if conditioning_latents is not None:
+				sample_voice, _ = load_voice(voice, load_latents=False)
+				if sample_voice and len(sample_voice) > 0:
+					sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
+			else:
+				sample_voice = None
+
+		voices[voice] = (voice_samples, conditioning_latents, sample_voice)
+		return voices[voice]
+
+	voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
 
 	if seed == 0:
 		seed = None
@@ -138,42 +151,80 @@ def generate(
 		print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
 		cvvp_weight = 0
 
+	def get_settings( override=None ):
+		settings = {
+			'temperature': float(temperature),
 
-	settings = {
-		'temperature': float(temperature),
+			'top_p': float(top_p),
+			'diffusion_temperature': float(diffusion_temperature),
+			'length_penalty': float(length_penalty),
+			'repetition_penalty': float(repetition_penalty),
+			'cond_free_k': float(cond_free_k),
 
-		'top_p': float(top_p),
-		'diffusion_temperature': float(diffusion_temperature),
-		'length_penalty': float(length_penalty),
-		'repetition_penalty': float(repetition_penalty),
-		'cond_free_k': float(cond_free_k),
+			'num_autoregressive_samples': num_autoregressive_samples,
+			'sample_batch_size': args.sample_batch_size,
+			'diffusion_iterations': diffusion_iterations,
 
-		'num_autoregressive_samples': num_autoregressive_samples,
-		'sample_batch_size': args.sample_batch_size,
-		'diffusion_iterations': diffusion_iterations,
+			'voice_samples': voice_samples,
+			'conditioning_latents': conditioning_latents,
 
-		'voice_samples': voice_samples,
-		'conditioning_latents': conditioning_latents,
-		'use_deterministic_seed': seed,
-		'return_deterministic_state': True,
-		'k': candidates,
-		'diffusion_sampler': diffusion_sampler,
-		'breathing_room': breathing_room,
-		'progress': progress,
-		'half_p': "Half Precision" in experimental_checkboxes,
-		'cond_free': "Conditioning-Free" in experimental_checkboxes,
-		'cvvp_amount': cvvp_weight,
-	}
+			'use_deterministic_seed': seed,
+			'return_deterministic_state': True,
+			'k': candidates,
+			'diffusion_sampler': diffusion_sampler,
+			'breathing_room': breathing_room,
+			'progress': progress,
+			'half_p': "Half Precision" in experimental_checkboxes,
+			'cond_free': "Conditioning-Free" in experimental_checkboxes,
+			'cvvp_amount': cvvp_weight,
+			'autoregressive_model': args.autoregressive_model,
+		}
 
-	# clamp it down for the insane users who want this
-	# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
-	sample_batch_size = args.sample_batch_size
-	if not sample_batch_size:
-		sample_batch_size = tts.autoregressive_batch_size
-	if num_autoregressive_samples < sample_batch_size:
-		settings['sample_batch_size'] = num_autoregressive_samples
+		# could be better to just do a ternary on everything above, but i am not a professional
+		if override is not None:
+			if 'voice' in override:
+				voice = override['voice']
 
-	if delimiter is None:
+				if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
+					dir = f'./training/{voice}-finetune/models/'
+					if os.path.exists(f'./training/finetunes/{voice}.pth'):
+						override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
+					elif os.path.isdir(dir):
+						counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
+						names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
+						override["autoregressive_model"] = names[-1]
+					else:
+						override["autoregressive_model"] = None
+
+					# necessary to ensure the right model gets loaded for the latents
+					tts.load_autoregressive_model( override["autoregressive_model"] )
+
+				fetched = fetch_voice(voice)
+
+				settings['voice_samples'] = fetched[0]
+				settings['conditioning_latents'] = fetched[1]
+
+			for k in override:
+				if k not in settings:
+					continue
+				settings[k] = override[k]
+
+		if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
+			tts.load_autoregressive_model( settings["autoregressive_model"] )
+
+		# clamp it down for the insane users who want this
+		# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
+		sample_batch_size = args.sample_batch_size
+		if not sample_batch_size:
+			sample_batch_size = tts.autoregressive_batch_size
+		if num_autoregressive_samples < sample_batch_size:
+			settings['sample_batch_size'] = num_autoregressive_samples
+			
+		return settings
+
+	settings = get_settings()
+
+	if not delimiter:
 		delimiter = "\n"
 	elif delimiter == "\\n":
 		delimiter = "\n"
@@ -189,7 +240,6 @@ def generate(
 	os.makedirs(outdir, exist_ok=True)
 
 	audio_cache = {}
-
 	resample = None
 
 	if tts.output_sample_rate != args.output_sample_rate:
@@ -238,12 +288,28 @@ def generate(
 				cut_text = f"[{prompt},] {cut_text}"
 		elif emotion != "None":
 			cut_text = f"[I am really {emotion.lower()},] {cut_text}"
-
+		
 		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
 		print(f"{progress.msg_prefix} Generating line: {cut_text}")
-
 		start_time = time.time()
-		gen, additionals = tts.tts(cut_text, **settings )
+
+		# do setting editing
+		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text) 
+		if match and len(match) > 0:
+			match = match[0]
+			try:
+				override = json.loads(match[0])
+			except Exception as e:
+				print(e)
+				raise Exception("Prompt settings editing requested, but received invalid JSON")
+
+			cut_text = match[1].strip()
+			new_settings = get_settings( override )
+
+			gen, additionals = tts.tts(cut_text, **new_settings )
+		else:
+			gen, additionals = tts.tts(cut_text, **settings )
+
 		seed = additionals[0]
 		run_time = time.time()-start_time
 		print(f"Generating line took {run_time} seconds")
@@ -327,19 +393,6 @@ def generate(
 		'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
 	}
 
-	"""
-	# kludgy yucky codesmells
-	for name in audio_cache:
-		if 'output' not in audio_cache[name]:
-			continue
-
-		#output_voices.append(f'{outdir}/{voice}_{name}.wav')
-		output_voices.append(name)
-		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
-				f.write(json.dumps(info, indent='\t') )
-	"""
-
 	if args.voice_fixer:
 		if not voicefixer:
 			progress(0, "Loading voicefix...")
@@ -1057,8 +1110,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
 	files = sorted(files)
 
 	previous_list = []
-	parsed_list = []
 	if skip_existings and os.path.exists(f'{outdir}/train.txt'):
+		parsed_list = []
 		with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
 			parsed_list = f.readlines()
 
@@ -1103,20 +1156,13 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
 			line = f"{sliced_name}|{segment['text'].strip()}"
 			transcription.append(line)
 			with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-				f.write(f'{line}\n')
+				f.write(f'\n{line}')
 
 		do_gc()
 	
 	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))
 
-	if len(parsed_list) > 0:
-		transcription = parsed_list + transcription
-	
-	joined = '\n'.join(transcription)
-	with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write(joined)
-
 	unload_whisper()
 
 	return f"Processed dataset to: {outdir}\n{joined}"
@@ -1688,6 +1734,22 @@ def read_generate_settings(file, read_latents=True):
 		latents,
 	)
 
+def version_check_tts( min_version ):
+	global tts
+	if not tts:
+		raise Exception("TTS is not initialized")
+
+	if not hasattr(tts, 'version'):
+		return False
+
+	if min_version[0] > tts.version[0]:
+		return True
+	if min_version[1] > tts.version[1]:
+		return True
+	if min_version[2] >= tts.version[2]:
+		return True
+	return False
+
 def load_tts( restart=False, model=None ):
 	global args
 	global tts