added mixing models (shamelessly inspired from voldy's web ui)
This commit is contained in:
parent
c89c648b4a
commit
f66281f10c
46
src/utils.py
46
src/utils.py
|
@ -2862,3 +2862,49 @@ def unload_whisper():
|
|||
print("Unloaded Whisper")
|
||||
|
||||
do_gc()
|
||||
|
||||
# shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74
|
||||
def merge_models( primary_model_name, secondary_model_name, alpha, progress=gr.Progress() ):
|
||||
key_blacklist = []
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
def read_model( filename ):
|
||||
print(f"Loading {filename}")
|
||||
return torch.load(filename)
|
||||
|
||||
theta_func = weighted_sum
|
||||
|
||||
theta_0 = read_model(primary_model_name)
|
||||
theta_1 = read_model(secondary_model_name)
|
||||
|
||||
for key in enumerate_progress(theta_0.keys(), desc="Merging...", progress=progress):
|
||||
if key in key_blacklist:
|
||||
print("Skipping ignored key:", key)
|
||||
continue
|
||||
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
if a.dtype != torch.float32 and a.dtype != torch.float16:
|
||||
print("Skipping key:", key, a.dtype)
|
||||
continue
|
||||
|
||||
if b.dtype != torch.float32 and b.dtype != torch.float16:
|
||||
print("Skipping key:", key, b.dtype)
|
||||
continue
|
||||
|
||||
theta_0[key] = theta_func(a, b, alpha)
|
||||
|
||||
del theta_1
|
||||
|
||||
primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0]
|
||||
secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0]
|
||||
suffix = "{:.3f}".format(alpha)
|
||||
output_path = f'./models/finetunes/{primary_basename}_{secondary_basename}_{suffix}_merge.pth'
|
||||
|
||||
torch.save(theta_0, output_path)
|
||||
message = f"Saved to {output_path}"
|
||||
print(message)
|
||||
return message
|
19
src/webui.py
19
src/webui.py
|
@ -27,6 +27,7 @@ GENERATE_SETTINGS = {}
|
|||
TRANSCRIBE_SETTINGS = {}
|
||||
EXEC_SETTINGS = {}
|
||||
TRAINING_SETTINGS = {}
|
||||
MERGER_SETTINGS = {}
|
||||
GENERATE_SETTINGS_ARGS = []
|
||||
|
||||
PRESETS = {
|
||||
|
@ -359,7 +360,7 @@ def setup_gradio():
|
|||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
||||
|
||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
|
||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
||||
|
||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
|
||||
|
@ -435,6 +436,17 @@ def setup_gradio():
|
|||
|
||||
with gr.Row():
|
||||
text_tokenizier_button = gr.Button(value="Tokenize Text")
|
||||
with gr.Tab("Model Merger"):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
MERGER_SETTINGS["model_a"] = gr.Dropdown( choices=autoregressive_models, label="Model A", type="value", value=autoregressive_models[0] )
|
||||
MERGER_SETTINGS["model_b"] = gr.Dropdown( choices=autoregressive_models, label="Model B", type="value", value=autoregressive_models[0] )
|
||||
with gr.Row():
|
||||
MERGER_SETTINGS["weight_slider"] = gr.Slider(label="Weight (from A to B)", value=0.5, minimum=0, maximum=1)
|
||||
with gr.Row():
|
||||
merger_button = gr.Button(value="Run Merger")
|
||||
with gr.Column():
|
||||
merger_output = gr.TextArea(label="Console Output", max_lines=8)
|
||||
with gr.Tab("Training"):
|
||||
with gr.Tab("Prepare Dataset"):
|
||||
with gr.Row():
|
||||
|
@ -777,6 +789,11 @@ def setup_gradio():
|
|||
outputs=text_tokenizier_output
|
||||
)
|
||||
|
||||
merger_button.click(merge_models,
|
||||
inputs=list(MERGER_SETTINGS.values()),
|
||||
outputs=merger_output
|
||||
)
|
||||
|
||||
refresh_configs.click(
|
||||
lambda: gr.update(choices=get_training_list()),
|
||||
inputs=None,
|
||||
|
|
Loading…
Reference in New Issue
Block a user