From cb273b84281548b1a81b1f149d2a6dfe9782a16e Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Thu, 9 Mar 2023 18:34:52 +0000
Subject: [PATCH] cleanup

---
 src/utils.py | 47 +++++++++++++++++++++++++++--------------------
 src/webui.py | 23 ++++++++++++++++++++---
 2 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 82ea25b..191e677 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -1274,6 +1274,15 @@ def optimize_training_settings( **kwargs ):
 			if vram > (k-1):
 				return v
 		return 1
+
+	if settings['gpus'] > get_device_count():
+		settings['gpus'] = get_device_count()
+		messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}")
+
+	if settings['gpus'] <= 1:
+		settings['gpus'] = 1
+	else:
+		messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.")
 	
 	# assuming you have equal GPUs
 	vram = get_device_vram() * settings['gpus']
@@ -1303,14 +1312,14 @@ def optimize_training_settings( **kwargs ):
 		messages.append("Resume path specified, but does not exist. Disabling...")
 
 	if settings['bitsandbytes']:
-		messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
+		messages.append("! EXPERIMENTAL ! BitsAndBytes requested.")
 
 	if settings['half_p']:
 		if settings['bitsandbytes']:
 			settings['half_p'] = False
 			messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
 		else:
-			messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
+			messages.append("! EXPERIMENTAL ! Half Precision requested.")
 			if not os.path.exists(get_halfp_model_path()):
 				convert_to_halfp()	
 
@@ -1343,12 +1352,21 @@ def save_training_settings( **kwargs ):
 	settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
 	messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
 
-	iterations_per_epoch = int(settings['iterations'] / settings['epochs'])
+	iterations_per_epoch = settings['iterations'] / settings['epochs']
 
 	settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
 	settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
 	settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
 
+	iterations_per_epoch = int(iterations_per_epoch)
+	
+	if settings['print_rate'] < 1:
+		settings['print_rate'] = 1
+	if settings['save_rate'] < 1:
+		settings['save_rate'] = 1
+	if settings['validation_rate'] < 1:
+		settings['validation_rate'] = 1
+
 	settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
 
 	settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
@@ -1809,9 +1827,7 @@ def save_args_settings():
 
 # super kludgy )`;
 def import_generate_settings(file="./config/generate.json"):
-	global GENERATE_SETTINGS_ARGS
-
-	defaults = {
+	res = {
 		'text': None,
 		'delimiter': None,
 		'emotion': None,
@@ -1836,19 +1852,9 @@ def import_generate_settings(file="./config/generate.json"):
 	}
 
 	settings, _ = read_generate_settings(file, read_latents=False)
-	
-	res = []
-	if GENERATE_SETTINGS_ARGS is not None:
-		for k in GENERATE_SETTINGS_ARGS:
-			if k not in defaults:
-				continue
-			res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k])
-	else:
-		for k in defaults:
-			res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k])
-
-	return tuple(res)
-
+	if settings is not None:
+		res.update(settings)
+	return res
 
 def reset_generation_settings():
 	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
@@ -1978,7 +1984,8 @@ def deduce_autoregressive_model(voice=None):
 		if os.path.isdir(dir):
 			counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
 			names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
-			return names[-1]
+			if len(names) > 0:
+				return names[-1]
 
 	if args.autoregressive_model != "auto":
 		return args.autoregressive_model
diff --git a/src/webui.py b/src/webui.py
index 281f3ac..39abcc4 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -27,6 +27,7 @@ GENERATE_SETTINGS = {}
 TRANSCRIBE_SETTINGS = {}
 EXEC_SETTINGS = {}
 TRAINING_SETTINGS = {}
+GENERATE_SETTINGS_ARGS = []
 
 PRESETS = {
 	'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
@@ -144,6 +145,18 @@ def history_view_results( voice ):
 		gr.Dropdown.update(choices=sorted(files))
 	)
 
+def import_generate_settings_proxy( file=None ):
+	global GENERATE_SETTINGS_ARGS
+	settings = import_generate_settings( file )
+
+	res = []
+	for k in GENERATE_SETTINGS_ARGS:
+		res.append(settings[k] if k in settings else None)
+	print(GENERATE_SETTINGS_ARGS)
+	print(settings)
+	print(res)
+	return tuple(res)
+
 def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
 	compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
 	return voice
@@ -221,7 +234,10 @@ def import_training_settings_proxy( voice ):
 		if k not in settings:
 			continue
 		output[k] = settings[k]
+
 	output = list(output.values())
+	print(list(TRAINING_SETTINGS.keys()))
+	print(output)
 	messages.append(f"Imported training settings: {injson}")
 
 	return output[:-1] + ["\n".join(messages)]
@@ -250,7 +266,7 @@ def history_copy_settings( voice, file ):
 def setup_gradio():
 	global args
 	global ui
-	
+
 	if not args.share:
 		def noop(function, return_value=None):
 			def wrapped(*args, **kwargs):
@@ -273,6 +289,7 @@ def setup_gradio():
 	autoregressive_models = get_autoregressive_models()
 	dataset_list = get_dataset_list()
 
+	global GENERATE_SETTINGS_ARGS
 	GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
 	for i in range(len(GENERATE_SETTINGS_ARGS)):
 		arg = GENERATE_SETTINGS_ARGS[i]
@@ -639,7 +656,7 @@ def setup_gradio():
 		)
 
 
-		copy_button.click(import_generate_settings,
+		copy_button.click(import_generate_settings_proxy,
 			inputs=audio_in, # JSON elements cannot be used as inputs
 			outputs=generate_settings
 		)
@@ -738,7 +755,7 @@ def setup_gradio():
 		)
 
 		if os.path.isfile('./config/generate.json'):
-			ui.load(import_generate_settings, inputs=None, outputs=generate_settings)
+			ui.load(import_generate_settings_proxy, inputs=None, outputs=generate_settings)
 		
 		if args.check_for_updates:
 			ui.load(check_for_updates)