From d7a5ad9fd9f1c728e5866bd515362bb297edab1b Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Tue, 7 Mar 2023 04:34:39 +0000
Subject: [PATCH] cleaned up some model loading logic, added 'auto' mode for AR
 model (deduced by current voice)

---
 src/utils.py | 107 ++++++++++++++++++++++++++++++++++-----------------
 tortoise-tts |   2 +-
 2 files changed, 72 insertions(+), 37 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 309b72f..0aec5a8 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -54,6 +54,8 @@ voicefixer = None
 whisper_model = None
 training_state = None
 
+current_voice = None
+
 def generate(
 	text,
 	delimiter,
@@ -117,10 +119,7 @@ def generate(
 		else:
 			progress(0, desc=f"Loading voice: {voice}")
 			# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
-			if hasattr(tts, 'autoregressive_model_hash'):
-				voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
-			else:
-				voice_samples, conditioning_latents = load_voice(voice)
+			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
 
 		if voice_samples and len(voice_samples) > 0:
 			conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
@@ -146,6 +145,10 @@ def generate(
 		print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
 		cvvp_weight = 0
 
+	autoregressive_model = args.autoregressive_model
+	if autoregressive_model == "auto":
+		autoregressive_model = deduce_autoregressive_model(voice)
+
 	def get_settings( override=None ):
 		settings = {
 			'temperature': float(temperature),
@@ -172,7 +175,7 @@ def generate(
 			'half_p': "Half Precision" in experimental_checkboxes,
 			'cond_free': "Conditioning-Free" in experimental_checkboxes,
 			'cvvp_amount': cvvp_weight,
-			'autoregressive_model': args.autoregressive_model,
+			'autoregressive_model': autoregressive_model,
 		}
 
 		# could be better to just do a ternary on everything above, but i am not a professional
@@ -180,18 +183,10 @@ def generate(
 			if 'voice' in override:
 				voice = override['voice']
 
-				if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
-					dir = f'./training/{voice}-finetune/models/'
-					if os.path.exists(f'./training/finetunes/{voice}.pth'):
-						override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
-					elif os.path.isdir(dir):
-						counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
-						names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
-						override["autoregressive_model"] = names[-1]
-					else:
-						override["autoregressive_model"] = None
+				if "autoregressive_model" in override:
+					if override["autoregressive_model"] == "auto":
+						override["autoregressive_model"] = deduce_autoregressive_model(voice)
 
-					# necessary to ensure the right model gets loaded for the latents
 					tts.load_autoregressive_model( override["autoregressive_model"] )
 
 				fetched = fetch_voice(voice)
@@ -204,8 +199,7 @@ def generate(
 					continue
 				settings[k] = override[k]
 
-		if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
-			tts.load_autoregressive_model( settings["autoregressive_model"] )
+		tts.load_autoregressive_model( settings["autoregressive_model"] )
 
 		# clamp it down for the insane users who want this
 		# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
@@ -302,7 +296,7 @@ def generate(
 
 			'datetime': datetime.now().isoformat(),
 			'model': tts.autoregressive_model_path,
-			'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
+			'model_hash': tts.autoregressive_model_hash 
 		}
 
 		if settings is not None:
@@ -331,7 +325,7 @@ def generate(
 			else:
 				if settings and "model_hash" in settings:
 					latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
-				elif hasattr(tts, "autoregressive_model_hash"):
+				else:
 					latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
 
 			if latents_path and os.path.exists(latents_path):
@@ -387,7 +381,7 @@ def generate(
 			used_settings['time'] = run_time
 			used_settings['datetime'] = datetime.now().isoformat(),
 			used_settings['model'] = tts.autoregressive_model_path
-			used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None
+			used_settings['model_hash'] = tts.autoregressive_model_hash
 
 			audio_cache[name] = {
 				'audio': audio,
@@ -540,6 +534,9 @@ def hash_file(path, algo="md5", buffer_size=0):
 	return "{0}".format(hash.hexdigest())
 
 def update_baseline_for_latents_chunks( voice ):
+	global current_voice
+	current_voice = voice
+
 	path = f'{get_voice_dir()}/{voice}/'
 	if not os.path.isdir(path):
 		return 1
@@ -583,6 +580,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
 	if hasattr(tts, "loading") and tts.loading:
 		raise Exception("TTS is still initializing...")
 
+	if args.autoregressive_model == "auto":
+		tts.load_autoregressive_model(deduce_autoregressive_model(voice))
+
 	if voice:
 		load_from_dataset = voice_latents_chunks == 0
 
@@ -620,10 +620,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
 	if len(conditioning_latents) == 4:
 		conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
 			
-	if hasattr(tts, 'autoregressive_model_hash'):
-		torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
-	else:
-		torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
+	torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
 
 	return conditioning_latents
 
@@ -1460,6 +1457,9 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
 		models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
 		found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
 
+	if len(found) > 0 or len(additionals) > 0:
+		base = ["auto"] + base
+
 	res = base + additionals + found
 	
 	if prefixed:
@@ -1815,28 +1815,29 @@ def version_check_tts( min_version ):
 		return True
 	return False
 
-def load_tts( restart=False, model=None ):
+def load_tts( restart=False, autoregressive_model=None ):
 	global args
 	global tts
 
 	if restart:
 		unload_tts()
 
+	if autoregressive_model:
+		args.autoregressive_model = autoregressive_model
+	else:
+		autoregressive_model = args.autoregressive_model
 
-	if model:
-		args.autoregressive_model = model
+	if autoregressive_model == "auto":
+		autoregressive_model = deduce_autoregressive_model()
 
-	print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})")
+	print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
 
 	tts_loading = True
 	try:
-		tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model)
+		tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
 	except Exception as e:
 		tts = TextToSpeech(minor_optimizations=not args.low_vram)
-		load_autoregressive_model(args.autoregressive_model)
-
-	if not hasattr(tts, 'autoregressive_model_hash'):
-		tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
+		load_autoregressive_model(autoregressive_model)
 
 	tts_loading = False
 
@@ -1858,6 +1859,37 @@ def unload_tts():
 def reload_tts( model=None ):
-	load_tts( restart=True, model=model )
+	load_tts( restart=True, autoregressive_model=model ) # load_tts() renamed its 'model' kwarg; the old name raises TypeError
 
+def get_current_voice():
+	"""Return the module-level current_voice if set, else the "voice" saved in ./config/generate.json, else None."""
+	global current_voice
+	if current_voice:
+		return current_voice
+
+	settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
+	# test membership on the dict itself: indexing settings['voice'] first raised KeyError / did a substring match
+	if settings and "voice" in settings:
+		return settings["voice"]
+	return None
+
+def deduce_autoregressive_model(voice=None):
+	"""Resolve "auto" to a model path: the voice's finetune, else its highest-numbered "<n>_gpt.pth", else args.autoregressive_model, else the base model."""
+	if not voice:
+		voice = get_current_voice()
+	if voice:
+		finetune = f'./training/finetunes/{voice}.pth'
+		if os.path.exists(finetune):
+			return finetune
+		models_dir = f'./training/{voice}-finetune/models/'
+		if os.path.isdir(models_dir):
+			# only "<number>_gpt.pth" files count, so int() cannot fail on stray files
+			steps = [ int(d[:-8]) for d in os.listdir(models_dir) if d.endswith("_gpt.pth") and d[:-8].isdecimal() ]
+			if steps: # a models dir without checkpoints falls through instead of raising IndexError
+				return f'{models_dir}{max(steps)}_gpt.pth'
+
+	if args.autoregressive_model != "auto":
+		return args.autoregressive_model
+	return get_model_path('autoregressive.pth')
+
 def update_autoregressive_model(autoregressive_model_path):
 	match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
 	if match:
@@ -1880,10 +1912,13 @@ def update_autoregressive_model(autoregressive_model_path):
 	if hasattr(tts, "loading") and tts.loading:
 		raise Exception("TTS is still initializing...")
 
+	if autoregressive_model_path == "auto":
+		autoregressive_model_path = deduce_autoregressive_model()
+
+	if autoregressive_model_path == tts.autoregressive_model_path:
+		return
 
-	print(f"Loading model: {autoregressive_model_path}")
 	tts.load_autoregressive_model(autoregressive_model_path)
-	print(f"Loaded model: {tts.autoregressive_model_path}")
 
 	do_gc()
 	
diff --git a/tortoise-tts b/tortoise-tts
index e2db36a..26133c2 160000
--- a/tortoise-tts
+++ b/tortoise-tts
@@ -1 +1 @@
-Subproject commit e2db36af602297501132f7f68331755f5904825a
+Subproject commit 26133c20314b77155e77be804b43909dab9809d6