From f66281f10cb952706fa97669d8d8c37cb7a261c1 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 29 Mar 2023 19:29:13 +0000 Subject: [PATCH] added mixing models (shamelessly inspired from voldy's web ui) --- src/utils.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++- src/webui.py | 19 ++++++++++++++++++- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/utils.py b/src/utils.py index d6f6d8f..a82285e 100755 --- a/src/utils.py +++ b/src/utils.py @@ -2861,4 +2861,50 @@ def unload_whisper(): whisper_model = None print("Unloaded Whisper") - do_gc() \ No newline at end of file + do_gc() + +# shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74 +def merge_models( primary_model_name, secondary_model_name, alpha, progress=gr.Progress() ): + key_blacklist = [] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def read_model( filename ): + print(f"Loading {filename}") + return torch.load(filename) + + theta_func = weighted_sum + + theta_0 = read_model(primary_model_name) + theta_1 = read_model(secondary_model_name) + + for key in enumerate_progress(theta_0.keys(), desc="Merging...", progress=progress): + if key in key_blacklist: + print("Skipping ignored key:", key) + continue + + a = theta_0[key] + b = theta_1[key] + + if a.dtype != torch.float32 and a.dtype != torch.float16: + print("Skipping key:", key, a.dtype) + continue + + if b.dtype != torch.float32 and b.dtype != torch.float16: + print("Skipping key:", key, b.dtype) + continue + + theta_0[key] = theta_func(a, b, alpha) + + del theta_1 + + primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0] + secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0] + suffix = "{:.3f}".format(alpha) + output_path = f'./models/finetunes/{primary_basename}_{secondary_basename}_{suffix}_merge.pth' + + torch.save(theta_0, output_path) + message = f"Saved to {output_path}" + print(message) + return message \ No newline at end of file diff --git a/src/webui.py b/src/webui.py index ac5d815..1abed5a 100755 --- a/src/webui.py +++ b/src/webui.py @@ -27,6 +27,7 @@ GENERATE_SETTINGS = {} TRANSCRIBE_SETTINGS = {} EXEC_SETTINGS = {} TRAINING_SETTINGS = {} +MERGER_SETTINGS = {} GENERATE_SETTINGS_ARGS = [] PRESETS = { @@ -359,7 +360,7 @@ def setup_gradio(): GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates") GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed") - preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" ) + preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" ) GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples") GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations") @@ -435,6 +436,17 @@ def setup_gradio(): with gr.Row(): text_tokenizier_button = gr.Button(value="Tokenize Text") + with gr.Tab("Model Merger"): + with gr.Column(): + with gr.Row(): + MERGER_SETTINGS["model_a"] = gr.Dropdown( choices=autoregressive_models, label="Model A", type="value", value=autoregressive_models[0] ) + MERGER_SETTINGS["model_b"] = gr.Dropdown( choices=autoregressive_models, label="Model B", type="value", value=autoregressive_models[0] ) + with gr.Row(): + MERGER_SETTINGS["weight_slider"] = gr.Slider(label="Weight (from A to B)", value=0.5, minimum=0, maximum=1) + with gr.Row(): + merger_button = gr.Button(value="Run Merger") + with gr.Column(): + merger_output = gr.TextArea(label="Console Output", max_lines=8) with gr.Tab("Training"): with gr.Tab("Prepare Dataset"): with gr.Row(): @@ -777,6 +789,11 @@ def setup_gradio(): outputs=text_tokenizier_output ) + merger_button.click(merge_models, + inputs=list(MERGER_SETTINGS.values()), + outputs=merger_output + ) + refresh_configs.click( lambda: gr.update(choices=get_training_list()), inputs=None,