From 9e64dad785bc6d6e0e948784b27028d5d4746e25 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Tue, 21 Feb 2023 21:50:05 +0000
Subject: [PATCH] clamp batch size to sample count when generating (for the
 sickos that want that), add a setting to remove non-final outputs after a
 generation, something else I forgot already
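
Clamping means a request for fewer autoregressive samples than the
configured batch size just runs one smaller batch, e.g. 16 samples with
sample_batch_size=64 now generates with a single batch of 16. The new
flag --prune-nonfinal-outputs (default on) deletes the intermediate
non-final wavs once a generation finishes, keeping only the final outputs.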

---
 src/utils.py | 79 ++++++++++++++++++++++++++++++++++++++++++----------
 src/webui.py | 10 ++++---
 2 files changed, 71 insertions(+), 18 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 970804d..eafe425 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -151,6 +151,11 @@ def generate(
 		'cvvp_amount': cvvp_weight,
 	}
 
+	# clamp the batch size down for the insane users who want this
+	# it would be wiser to force the sample count up to the batch size, but this is what the user wants
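+	# e.g. asking for 16 samples with a batch size of 64 now runs a single batch of 16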
+	if num_autoregressive_samples < args.sample_batch_size:
+		settings['sample_batch_size'] = num_autoregressive_samples
+
 	if delimiter is None:
 		delimiter = "\n"
 	elif delimiter == "\\n":
@@ -301,30 +306,61 @@ def generate(
 		'time': time.time()-full_start_time,
 	}
 
+	"""
 	# kludgy yucky codesmells
 	for name in audio_cache:
 		if 'output' not in audio_cache[name]:
 			continue
 
-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
-		with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
-			f.write(json.dumps(info, indent='\t') )
+		#output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(name)
+		if not args.embed_output_metadata:
+			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+				f.write(json.dumps(info, indent='\t') )
+	"""
 
 	if args.voice_fixer:
 		if not voicefixer:
 			load_voicefixer()
 
-		fixed_output_voices = []
-		for path in progress.tqdm(output_voices, desc="Running voicefix..."):
-			fixed = path.replace(".wav", "_fixed.wav")
+		fixed_cache = {}
+		for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
+			del audio_cache[name]['audio']
+			if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+				continue
+
+			path = f'{outdir}/{voice}_{name}.wav'
+			fixed = f'{outdir}/{voice}_{name}_fixed.wav'
 			voicefixer.restore(
 				input=path,
 				output=fixed,
 				cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
 				#mode=mode,
 			)
-			fixed_output_voices.append(fixed)
-		output_voices = fixed_output_voices
+			
+			fixed_cache[f'{name}_fixed'] = {
+				'text': audio_cache[name]['text'],
+				'time': audio_cache[name]['time'],
+				'output': True
+			}
+			audio_cache[name]['output'] = False
+		
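+		# merge after the loop; inserting into audio_cache while iterating it would raise a RuntimeError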
+		for name in fixed_cache:
+			audio_cache[name] = fixed_cache[name]
+
+	for name in audio_cache:
+		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+			if args.prune_nonfinal_outputs:
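+				# mark the entry so the metadata embed pass below knows the file no longer exists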
+				audio_cache[name]['pruned'] = True
+				os.remove(f'{outdir}/{voice}_{name}.wav')
+			continue
+
+		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+
+		if not args.embed_output_metadata:
+			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+				f.write(json.dumps(info, indent='\t') )
+
 
 	if voice and voice != "random" and conditioning_latents is not None:
 		with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
@@ -332,6 +368,9 @@ def generate(
 
 	if args.embed_output_metadata:
 		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
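+			# anything pruned above was deleted from disk, so there's nothing to embed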
+			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
+				continue
+
 			info['text'] = audio_cache[name]['text']
 			info['time'] = audio_cache[name]['time']
 
@@ -490,7 +529,7 @@ def stop_training():
 	training_process.kill()
 	return "Training cancelled"
 
-def get_halfp_model():
+def get_halfp_model_path():
 	autoregressive_model_path = get_model_path('autoregressive.pth')
 	return autoregressive_model_path.replace(".pth", "_half.pth")
 
@@ -501,7 +540,7 @@ def convert_to_halfp():
 	for k in model:
 		model[k] = model[k].half()
 
-	outfile = get_halfp_model()
+	outfile = get_halfp_model_path()
 	torch.save(model, outfile)
 	print(f'Converted model to half precision: {outfile}')
 
@@ -733,13 +772,21 @@ def import_voices(files, saveAs=None, progress=None):
 
 			print(f"Imported voice to {path}")
 
-def get_voice_list(dir=get_voice_dir()):
+def get_voice_list(dir=get_voice_dir(), append_defaults=False):
 	os.makedirs(dir, exist_ok=True)
-	return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
+	res = sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
+	if append_defaults:
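+		# "random" and "microphone" are pseudo-voices; only the generation dropdown wants them appended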
+		res = res + ["random", "microphone"]
+	return res
 
 def get_autoregressive_models(dir="./models/finetunes/"):
 	os.makedirs(dir, exist_ok=True)
-	return [get_model_path('autoregressive.pth')] + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
+	base = [get_model_path('autoregressive.pth')]
+	halfp = get_halfp_model_path()
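+	# only list the half-precision model if convert_to_halfp() has actually generated one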
+	if os.path.exists(halfp):
+		base.append(halfp)
+
+	return base + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d.endswith(".pth") ])
 
 def get_dataset_list(dir="./training/"):
 	return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
@@ -842,6 +889,7 @@ def setup_args():
 		'force-cpu-for-conditioning-latents': False,
 		'defer-tts-load': False,
 		'device-override': None,
+		'prune-nonfinal-outputs': True,
 		'whisper-model': "base",
 		'autoregressive-model': None,
 		'concurrency-count': 2,
@@ -867,6 +915,7 @@ def setup_args():
 	parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
 	parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
 	parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
+	parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
 	parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
 	parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
 	parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
@@ -901,7 +950,7 @@ def setup_args():
 	
 	return args
 
-def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
+def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
 	global args
 
 	args.listen = listen
@@ -911,6 +960,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
 	args.low_vram = low_vram
 	args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
 	args.defer_tts_load = defer_tts_load
+	args.prune_nonfinal_outputs = prune_nonfinal_outputs
 	args.device_override = device_override
 	args.sample_batch_size = sample_batch_size
 	args.embed_output_metadata = embed_output_metadata
@@ -932,6 +982,7 @@ def save_args_settings():
 		'models-from-local-only':args.models_from_local_only,
 		'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
 		'defer-tts-load': args.defer_tts_load,
+		'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
 		'device-override': args.device_override,
 		'whisper-model': args.whisper_model,
 		'autoregressive-model': args.autoregressive_model,
diff --git a/src/webui.py b/src/webui.py
index fb45620..a6ab00c 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -238,7 +238,7 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
 
 def update_voices():
 	return (
-		gr.Dropdown.update(choices=get_voice_list()),
+		gr.Dropdown.update(choices=get_voice_list(append_defaults=True)),
 		gr.Dropdown.update(choices=get_voice_list()),
 		gr.Dropdown.update(choices=get_voice_list("./results/")),
 	)
@@ -277,7 +277,8 @@ def setup_gradio():
 
 					emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
 					prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
-					voice = gr.Dropdown(get_voice_list(), label="Voice", type="value")
+					voice_list = get_voice_list(append_defaults=True)
+					voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
 					mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
 					refresh_voices = gr.Button(value="Refresh Voice List")
 					voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
@@ -410,14 +411,15 @@ def setup_gradio():
 						gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
 						gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
 						gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
+						gr.Checkbox(label="Delete Non-Final Outputs", value=args.prune_nonfinal_outputs),
 						gr.Textbox(label="Device Override", value=args.device_override),
 					]
 				with gr.Column():
 					exec_inputs = exec_inputs + [
 						gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
 						gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
-						gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
-						gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
+						gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
+						gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
 					]
 					
 					autoregressive_models = get_autoregressive_models()