From cb273b84281548b1a81b1f149d2a6dfe9782a16e Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 9 Mar 2023 18:34:52 +0000 Subject: [PATCH] cleanup --- src/utils.py | 47 +++++++++++++++++++++++++++-------------------- src/webui.py | 23 ++++++++++++++++++++--- 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/src/utils.py b/src/utils.py index 82ea25b..191e677 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1274,6 +1274,15 @@ def optimize_training_settings( **kwargs ): if vram > (k-1): return v return 1 + + if settings['gpus'] > get_device_count(): + settings['gpus'] = get_device_count() + messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}") + + if settings['gpus'] <= 1: + settings['gpus'] = 1 + else: + messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.") # assuming you have equal GPUs vram = get_device_vram() * settings['gpus'] @@ -1303,14 +1312,14 @@ def optimize_training_settings( **kwargs ): messages.append("Resume path specified, but does not exist. Disabling...") if settings['bitsandbytes']: - messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !") + messages.append("! EXPERIMENTAL ! BitsAndBytes requested.") if settings['half_p']: if settings['bitsandbytes']: settings['half_p'] = False messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...") else: - messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !") + messages.append("! EXPERIMENTAL ! Half Precision requested.") if not os.path.exists(get_halfp_model_path()): convert_to_halfp() @@ -1343,12 +1352,21 @@ def save_training_settings( **kwargs ): settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps") - iterations_per_epoch = int(settings['iterations'] / settings['epochs']) + iterations_per_epoch = settings['iterations'] / settings['epochs'] settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch) settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch) settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch) + iterations_per_epoch = int(iterations_per_epoch) + + if settings['print_rate'] < 1: + settings['print_rate'] = 1 + if settings['save_rate'] < 1: + settings['save_rate'] = 1 + if settings['validation_rate'] < 1: + settings['validation_rate'] = 1 + settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size']) settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) @@ -1809,9 +1827,7 @@ def save_args_settings(): # super kludgy )`; def import_generate_settings(file="./config/generate.json"): - global GENERATE_SETTINGS_ARGS - - defaults = { + res = { 'text': None, 'delimiter': None, 'emotion': None, @@ -1836,19 +1852,9 @@ def import_generate_settings(file="./config/generate.json"): } settings, _ = read_generate_settings(file, read_latents=False) - - res = [] - if GENERATE_SETTINGS_ARGS is not None: - for k in GENERATE_SETTINGS_ARGS: - if k not in defaults: - continue - res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k]) - else: - for k in defaults: - res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k]) - - return tuple(res) - + if settings is not None: + res.update(settings) + return res def reset_generation_settings(): with open(f'./config/generate.json', 'w', encoding="utf-8") as f: @@ -1978,7 +1984,8 @@ def deduce_autoregressive_model(voice=None): if os.path.isdir(dir): counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) names = [ f'{dir}/{d}_gpt.pth' for d in counts ] - return names[-1] + if len(names) > 0: + return names[-1] if args.autoregressive_model != "auto": return args.autoregressive_model diff --git a/src/webui.py b/src/webui.py index 281f3ac..39abcc4 100755 --- a/src/webui.py +++ b/src/webui.py @@ -27,6 +27,7 @@ GENERATE_SETTINGS = {} TRANSCRIBE_SETTINGS = {} EXEC_SETTINGS = {} TRAINING_SETTINGS = {} +GENERATE_SETTINGS_ARGS = [] PRESETS = { 'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False}, @@ -144,6 +145,18 @@ def history_view_results( voice ): gr.Dropdown.update(choices=sorted(files)) ) +def import_generate_settings_proxy( file=None ): + global GENERATE_SETTINGS_ARGS + settings = import_generate_settings( file ) + + res = [] + for k in GENERATE_SETTINGS_ARGS: + res.append(settings[k] if k in settings else None) + print(GENERATE_SETTINGS_ARGS) + print(settings) + print(res) + return tuple(res) + def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress ) return voice @@ -221,7 +234,10 @@ def import_training_settings_proxy( voice ): if k not in settings: continue output[k] = settings[k] + output = list(output.values()) + print(list(TRAINING_SETTINGS.keys())) + print(output) messages.append(f"Imported training settings: {injson}") return output[:-1] + ["\n".join(messages)] @@ -250,7 +266,7 @@ def history_copy_settings( voice, file ): def setup_gradio(): global args global ui - + if not args.share: def noop(function, return_value=None): def wrapped(*args, **kwargs): @@ -273,6 +289,7 @@ def setup_gradio(): autoregressive_models = get_autoregressive_models() dataset_list = get_dataset_list() + global GENERATE_SETTINGS_ARGS GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1] for i in range(len(GENERATE_SETTINGS_ARGS)): arg = GENERATE_SETTINGS_ARGS[i] @@ -639,7 +656,7 @@ def setup_gradio(): ) - copy_button.click(import_generate_settings, + copy_button.click(import_generate_settings_proxy, inputs=audio_in, # JSON elements cannot be used as inputs outputs=generate_settings ) @@ -738,7 +755,7 @@ def setup_gradio(): ) if os.path.isfile('./config/generate.json'): - ui.load(import_generate_settings, inputs=None, outputs=generate_settings) + ui.load(import_generate_settings_proxy, inputs=None, outputs=generate_settings) if args.check_for_updates: ui.load(check_for_updates)