add progress bar to modelmerger
This commit is contained in:
parent
7cfc645030
commit
c7e50425f6
|
@ -172,6 +172,17 @@ function submit_img2img(){
|
|||
return res
|
||||
}
|
||||
|
||||
function modelmerger(){
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||
|
||||
gradioApp().getElementById('modelmerger_result').innerHTML = ''
|
||||
|
||||
var res = create_submit_args(arguments)
|
||||
res[0] = id
|
||||
return res
|
||||
}
|
||||
|
||||
|
||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||
name_ = prompt('Style name:')
|
||||
|
|
|
@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
|
|||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
shared.state.job_count = 1
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
shared.state.end()
|
||||
return [message, *[gr.update() for _ in range(4)]]
|
||||
return [*[gr.update() for _ in range(4)], message]
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
if theta_func1:
|
||||
shared.state.job_count += 1
|
||||
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
shared.state.sampling_steps = len(theta_1.keys())
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
|
@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
del theta_2
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
|
||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
shared.state.sampling_steps = len(theta_0.keys())
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
|
||||
|
@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
|
||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
|
@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
|
@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||
|
|
|
@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):
|
|||
|
||||
if job_count > 0:
|
||||
progress += job_no / job_count
|
||||
if sampling_steps > 0:
|
||||
if sampling_steps > 0 and job_count > 0:
|
||||
progress += 1 / job_count * sampling_step / sampling_steps
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
|
|
@ -1208,8 +1208,9 @@ def create_ui():
|
|||
with gr.Row():
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
|
||||
with gr.Group(elem_id="modelmerger_results_panel"):
|
||||
modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as train_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
|
@ -1753,12 +1754,14 @@ def create_ui():
|
|||
print("Error loading/saving model file:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||
return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||
return results
|
||||
|
||||
modelmerger_merge.click(
|
||||
fn=modelmerger,
|
||||
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
||||
_js='modelmerger',
|
||||
inputs=[
|
||||
dummy_component,
|
||||
primary_model_name,
|
||||
secondary_model_name,
|
||||
tertiary_model_name,
|
||||
|
@ -1770,11 +1773,11 @@ def create_ui():
|
|||
config_source,
|
||||
],
|
||||
outputs=[
|
||||
submit_result,
|
||||
primary_model_name,
|
||||
secondary_model_name,
|
||||
tertiary_model_name,
|
||||
component_dict['sd_model_checkpoint'],
|
||||
modelmerger_result,
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -737,6 +737,11 @@ footer {
|
|||
line-height: 2.4em;
|
||||
}
|
||||
|
||||
#modelmerger_results_container{
|
||||
margin-top: 1em;
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
||||
The rtl media type will only be activated by the logic in javascript/localization.js.
|
||||
If you change anything above, you need to make sure it is RTL compliant by just running
|
||||
|
|
Loading…
Reference in New Issue
Block a user