added mixing models (shamelessly inspired from voldy's web ui)
This commit is contained in:
parent
c89c648b4a
commit
f66281f10c
48
src/utils.py
48
src/utils.py
|
@ -2861,4 +2861,50 @@ def unload_whisper():
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
print("Unloaded Whisper")
|
print("Unloaded Whisper")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
|
# shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74
|
||||||
|
def merge_models( primary_model_name, secondary_model_name, alpha, progress=gr.Progress() ):
|
||||||
|
key_blacklist = []
|
||||||
|
|
||||||
|
def weighted_sum(theta0, theta1, alpha):
|
||||||
|
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||||
|
|
||||||
|
def read_model( filename ):
|
||||||
|
print(f"Loading {filename}")
|
||||||
|
return torch.load(filename)
|
||||||
|
|
||||||
|
theta_func = weighted_sum
|
||||||
|
|
||||||
|
theta_0 = read_model(primary_model_name)
|
||||||
|
theta_1 = read_model(secondary_model_name)
|
||||||
|
|
||||||
|
for key in enumerate_progress(theta_0.keys(), desc="Merging...", progress=progress):
|
||||||
|
if key in key_blacklist:
|
||||||
|
print("Skipping ignored key:", key)
|
||||||
|
continue
|
||||||
|
|
||||||
|
a = theta_0[key]
|
||||||
|
b = theta_1[key]
|
||||||
|
|
||||||
|
if a.dtype != torch.float32 and a.dtype != torch.float16:
|
||||||
|
print("Skipping key:", key, a.dtype)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if b.dtype != torch.float32 and b.dtype != torch.float16:
|
||||||
|
print("Skipping key:", key, b.dtype)
|
||||||
|
continue
|
||||||
|
|
||||||
|
theta_0[key] = theta_func(a, b, alpha)
|
||||||
|
|
||||||
|
del theta_1
|
||||||
|
|
||||||
|
primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0]
|
||||||
|
secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0]
|
||||||
|
suffix = "{:.3f}".format(alpha)
|
||||||
|
output_path = f'./models/finetunes/{primary_basename}_{secondary_basename}_{suffix}_merge.pth'
|
||||||
|
|
||||||
|
torch.save(theta_0, output_path)
|
||||||
|
message = f"Saved to {output_path}"
|
||||||
|
print(message)
|
||||||
|
return message
|
19
src/webui.py
19
src/webui.py
|
@ -27,6 +27,7 @@ GENERATE_SETTINGS = {}
|
||||||
TRANSCRIBE_SETTINGS = {}
|
TRANSCRIBE_SETTINGS = {}
|
||||||
EXEC_SETTINGS = {}
|
EXEC_SETTINGS = {}
|
||||||
TRAINING_SETTINGS = {}
|
TRAINING_SETTINGS = {}
|
||||||
|
MERGER_SETTINGS = {}
|
||||||
GENERATE_SETTINGS_ARGS = []
|
GENERATE_SETTINGS_ARGS = []
|
||||||
|
|
||||||
PRESETS = {
|
PRESETS = {
|
||||||
|
@ -359,7 +360,7 @@ def setup_gradio():
|
||||||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
||||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
||||||
|
|
||||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
|
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
||||||
|
|
||||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
||||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
|
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
|
||||||
|
@ -435,6 +436,17 @@ def setup_gradio():
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_tokenizier_button = gr.Button(value="Tokenize Text")
|
text_tokenizier_button = gr.Button(value="Tokenize Text")
|
||||||
|
with gr.Tab("Model Merger"):
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
MERGER_SETTINGS["model_a"] = gr.Dropdown( choices=autoregressive_models, label="Model A", type="value", value=autoregressive_models[0] )
|
||||||
|
MERGER_SETTINGS["model_b"] = gr.Dropdown( choices=autoregressive_models, label="Model B", type="value", value=autoregressive_models[0] )
|
||||||
|
with gr.Row():
|
||||||
|
MERGER_SETTINGS["weight_slider"] = gr.Slider(label="Weight (from A to B)", value=0.5, minimum=0, maximum=1)
|
||||||
|
with gr.Row():
|
||||||
|
merger_button = gr.Button(value="Run Merger")
|
||||||
|
with gr.Column():
|
||||||
|
merger_output = gr.TextArea(label="Console Output", max_lines=8)
|
||||||
with gr.Tab("Training"):
|
with gr.Tab("Training"):
|
||||||
with gr.Tab("Prepare Dataset"):
|
with gr.Tab("Prepare Dataset"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -777,6 +789,11 @@ def setup_gradio():
|
||||||
outputs=text_tokenizier_output
|
outputs=text_tokenizier_output
|
||||||
)
|
)
|
||||||
|
|
||||||
|
merger_button.click(merge_models,
|
||||||
|
inputs=list(MERGER_SETTINGS.values()),
|
||||||
|
outputs=merger_output
|
||||||
|
)
|
||||||
|
|
||||||
refresh_configs.click(
|
refresh_configs.click(
|
||||||
lambda: gr.update(choices=get_training_list()),
|
lambda: gr.update(choices=get_training_list()),
|
||||||
inputs=None,
|
inputs=None,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user