added model merging (shamelessly inspired from voldy's web ui)

remotes/1710271458855113467/master
mrq 2023-03-29 19:29:13 +07:00
parent c89c648b4a
commit f66281f10c
2 changed files with 65 additions and 2 deletions

@ -2861,4 +2861,50 @@ def unload_whisper():
whisper_model = None
print("Unloaded Whisper")
do_gc()
do_gc()
# shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74
def merge_models( primary_model_name, secondary_model_name, alpha, progress=gr.Progress() ):
    """
    Merge two model checkpoints by linear interpolation of their weights.

    Parameters:
        primary_model_name: path to model A's checkpoint (its keys drive the merge).
        secondary_model_name: path to model B's checkpoint.
        alpha: interpolation weight in [0, 1]; 0 keeps A, 1 keeps B.
        progress: Gradio progress tracker for the UI.

    Returns:
        A status message string naming the saved output path
        (./models/finetunes/<A>_<B>_<alpha>_merge.pth).
    """
    key_blacklist = []

    def weighted_sum(theta0, theta1, alpha):
        # Classic lerp: (1 - alpha) * A + alpha * B
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def read_model( filename ):
        # Was: print(f"Loading (unknown)") — f-string had lost its placeholder.
        print(f"Loading {filename}")
        # NOTE(security): torch.load unpickles arbitrary data; only load trusted checkpoints.
        return torch.load(filename)

    theta_func = weighted_sum

    theta_0 = read_model(primary_model_name)
    theta_1 = read_model(secondary_model_name)

    for key in enumerate_progress(theta_0.keys(), desc="Merging...", progress=progress):
        if key in key_blacklist:
            print("Skipping ignored key:", key)
            continue

        # Robustness: a key present only in model A would otherwise raise KeyError.
        if key not in theta_1:
            print("Skipping key missing from secondary model:", key)
            continue

        a = theta_0[key]
        b = theta_1[key]

        # Only merge float tensors; leave ints/bools/etc. as model A's values.
        if a.dtype not in (torch.float32, torch.float16):
            print("Skipping key:", key, a.dtype)
            continue
        if b.dtype not in (torch.float32, torch.float16):
            print("Skipping key:", key, b.dtype)
            continue

        theta_0[key] = theta_func(a, b, alpha)

    # Free model B before saving to keep peak memory down.
    del theta_1

    primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0]
    secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0]
    suffix = "{:.3f}".format(alpha)
    output_path = f'./models/finetunes/{primary_basename}_{secondary_basename}_{suffix}_merge.pth'
    torch.save(theta_0, output_path)

    message = f"Saved to {output_path}"
    print(message)
    return message

@ -27,6 +27,7 @@ GENERATE_SETTINGS = {}
# Per-tab registries mapping a setting name to its Gradio UI component.
TRANSCRIBE_SETTINGS = {}
EXEC_SETTINGS = {}
TRAINING_SETTINGS = {}
# Components of the "Model Merger" tab; their values are passed to merge_models().
MERGER_SETTINGS = {}
# Ordered list of argument names for the generate settings — presumably kept in
# sync with GENERATE_SETTINGS elsewhere in the file; confirm against full source.
GENERATE_SETTINGS_ARGS = []
PRESETS = {
@ -359,7 +360,7 @@ def setup_gradio():
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
@ -435,6 +436,17 @@ def setup_gradio():
with gr.Row():
text_tokenizier_button = gr.Button(value="Tokenize Text")
with gr.Tab("Model Merger"):
with gr.Column():
with gr.Row():
MERGER_SETTINGS["model_a"] = gr.Dropdown( choices=autoregressive_models, label="Model A", type="value", value=autoregressive_models[0] )
MERGER_SETTINGS["model_b"] = gr.Dropdown( choices=autoregressive_models, label="Model B", type="value", value=autoregressive_models[0] )
with gr.Row():
MERGER_SETTINGS["weight_slider"] = gr.Slider(label="Weight (from A to B)", value=0.5, minimum=0, maximum=1)
with gr.Row():
merger_button = gr.Button(value="Run Merger")
with gr.Column():
merger_output = gr.TextArea(label="Console Output", max_lines=8)
with gr.Tab("Training"):
with gr.Tab("Prepare Dataset"):
with gr.Row():
@ -777,6 +789,11 @@ def setup_gradio():
outputs=text_tokenizier_output
)
merger_button.click(merge_models,
inputs=list(MERGER_SETTINGS.values()),
outputs=merger_output
)
refresh_configs.click(
lambda: gr.update(choices=get_training_list()),
inputs=None,