diff --git a/modules/extras.py b/modules/extras.py index 88eea22e..367c15cc 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -278,6 +278,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam shared.state.begin() shared.state.job = 'model-merge' + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [message, *[gr.update() for _ in range(4)]] + def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -288,16 +293,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam return theta0 + (alpha * theta1_2_diff) if not primary_model_name: - shared.state.textinfo = "Failed: Merging requires a primary model." - shared.state.end() - return ["Failed: Merging requires a primary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return fail("Failed: Merging requires a primary model.") primary_model_info = sd_models.checkpoints_list[primary_model_name] if not secondary_model_name: - shared.state.textinfo = "Failed: Merging requires a secondary model." - shared.state.end() - return ["Failed: Merging requires a secondary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return fail("Failed: Merging requires a secondary model.") secondary_model_info = sd_models.checkpoints_list[secondary_model_name] @@ -308,9 +309,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_func1, theta_func2 = theta_funcs[interp_method] if theta_func1 and not tertiary_model_name: - shared.state.textinfo = "Failed: Interpolation method requires a tertiary model." - shared.state.end() - return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None