From 9e64dad785bc6d6e0e948784b27028d5d4746e25 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 21 Feb 2023 21:50:05 +0000 Subject: [PATCH] clamp batch size to sample count when generating for the sickos that want that, added setting to remove non-final output after a generation, something else I forgot already --- src/utils.py | 79 ++++++++++++++++++++++++++++++++++++++++++---------- src/webui.py | 10 ++++--- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/src/utils.py b/src/utils.py index 970804d..eafe425 100755 --- a/src/utils.py +++ b/src/utils.py @@ -151,6 +151,11 @@ def generate( 'cvvp_amount': cvvp_weight, } + # clamp it down for the insane users who want this + # it would be wiser to enforce the sample size to the batch size, but this is what the user wants + if num_autoregressive_samples < args.sample_batch_size: + settings['sample_batch_size'] = num_autoregressive_samples + if delimiter is None: delimiter = "\n" elif delimiter == "\\n": @@ -301,30 +306,61 @@ def generate( 'time': time.time()-full_start_time, } + """ # kludgy yucky codesmells for name in audio_cache: if 'output' not in audio_cache[name]: continue - output_voices.append(f'{outdir}/{voice}_{name}.wav') - with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: - f.write(json.dumps(info, indent='\t') ) + #output_voices.append(f'{outdir}/{voice}_{name}.wav') + output_voices.append(name) + if not args.embed_output_metadata: + with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: + f.write(json.dumps(info, indent='\t') ) + """ if args.voice_fixer: if not voicefixer: load_voicefixer() - fixed_output_voices = [] - for path in progress.tqdm(output_voices, desc="Running voicefix..."): - fixed = path.replace(".wav", "_fixed.wav") + fixed_cache = {} + for name in progress.tqdm(audio_cache, desc="Running voicefix..."): + del audio_cache[name]['audio'] + if 'output' not in audio_cache[name] or not audio_cache[name]['output']: + continue + + path = f'{outdir}/{voice}_{name}.wav' + fixed = f'{outdir}/{voice}_{name}_fixed.wav' voicefixer.restore( input=path, output=fixed, cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, #mode=mode, ) - fixed_output_voices.append(fixed) - output_voices = fixed_output_voices + + fixed_cache[f'{name}_fixed'] = { + 'text': audio_cache[name]['text'], + 'time': audio_cache[name]['time'], + 'output': True + } + audio_cache[name]['output'] = False + + for name in fixed_cache: + audio_cache[name] = fixed_cache[name] + + for name in audio_cache: + if 'output' not in audio_cache[name] or not audio_cache[name]['output']: + if args.prune_nonfinal_outputs: + audio_cache[name]['pruned'] = True + os.remove(f'{outdir}/{voice}_{name}.wav') + continue + + output_voices.append(f'{outdir}/{voice}_{name}.wav') + + if not args.embed_output_metadata: + with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: + f.write(json.dumps(info, indent='\t') ) + if voice and voice != "random" and conditioning_latents is not None: with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f: @@ -332,6 +368,9 @@ def generate( if args.embed_output_metadata: for name in progress.tqdm(audio_cache, desc="Embedding metadata..."): + if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']: + continue + info['text'] = audio_cache[name]['text'] info['time'] = audio_cache[name]['time'] @@ -490,7 +529,7 @@ def stop_training(): training_process.kill() return "Training cancelled" -def get_halfp_model(): +def get_halfp_model_path(): autoregressive_model_path = get_model_path('autoregressive.pth') return autoregressive_model_path.replace(".pth", "_half.pth") @@ -501,7 +540,7 @@ def convert_to_halfp(): for k in model: model[k] = model[k].half() - outfile = get_halfp_model() + outfile = get_halfp_model_path() torch.save(model, outfile) print(f'Converted model to half precision: {outfile}') @@ -733,13 +772,21 @@ def import_voices(files, saveAs=None, progress=None): print(f"Imported voice to {path}") -def get_voice_list(dir=get_voice_dir()): +def get_voice_list(dir=get_voice_dir(), append_defaults=False): os.makedirs(dir, exist_ok=True) - return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"] + res = sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + if append_defaults: + res = res + ["random", "microphone"] + return res def get_autoregressive_models(dir="./models/finetunes/"): os.makedirs(dir, exist_ok=True) - return [get_model_path('autoregressive.pth')] + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) + base = [get_model_path('autoregressive.pth')] + halfp = get_halfp_model_path() + if os.path.exists(halfp): + base.append(halfp) + + return base + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) def get_dataset_list(dir="./training/"): return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ]) @@ -842,6 +889,7 @@ def setup_args(): 'force-cpu-for-conditioning-latents': False, 'defer-tts-load': False, 'device-override': None, + 'prune-nonfinal-outputs': True, 'whisper-model': "base", 'autoregressive-model': None, 'concurrency-count': 2, @@ -867,6 +915,7 @@ def setup_args(): parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.") parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)") parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model") + parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation") parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch") parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.") parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.") @@ -901,7 +950,7 @@ def setup_args(): return args -def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): +def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): global args args.listen = listen @@ -911,6 +960,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v args.low_vram = low_vram args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents args.defer_tts_load = defer_tts_load + args.prune_nonfinal_outputs = prune_nonfinal_outputs args.device_override = device_override args.sample_batch_size = sample_batch_size args.embed_output_metadata = embed_output_metadata @@ -932,6 +982,7 @@ def save_args_settings(): 'models-from-local-only':args.models_from_local_only, 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, 'defer-tts-load': args.defer_tts_load, + 'prune-nonfinal-outputs': args.prune_nonfinal_outputs, 'device-override': args.device_override, 'whisper-model': args.whisper_model, 'autoregressive-model': args.autoregressive_model, diff --git a/src/webui.py b/src/webui.py index fb45620..a6ab00c 100755 --- a/src/webui.py +++ b/src/webui.py @@ -238,7 +238,7 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra def update_voices(): return ( - gr.Dropdown.update(choices=get_voice_list()), + gr.Dropdown.update(choices=get_voice_list(append_defaults=True)), gr.Dropdown.update(choices=get_voice_list()), gr.Dropdown.update(choices=get_voice_list("./results/")), ) @@ -277,7 +277,8 @@ def setup_gradio(): emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True ) prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)") - voice = gr.Dropdown(get_voice_list(), label="Voice", type="value") + voice_list = get_voice_list(append_defaults=True) + voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" ) refresh_voices = gr.Button(value="Refresh Voice List") voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) @@ -410,14 +411,15 @@ def setup_gradio(): gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load), + gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs), gr.Textbox(label="Device Override", value=args.device_override), ] with gr.Column(): exec_inputs = exec_inputs + [ gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count), - gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate), - gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume), + gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate), + gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume), ] autoregressive_models = get_autoregressive_models()