From 954091697fce7a1b7997d5f3d73551f793f6bebc Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 09:10:07 +0300
Subject: [PATCH] add an option to copy config from one of models in checkpoint
 merger

---
 modules/extras.py | 30 +++++++++++++++++++++++++++++-
 modules/ui.py     |  9 ++++++---
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/modules/extras.py b/modules/extras.py
index 7407bfe3..a03d558e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -3,6 +3,7 @@ import math
 import os
 import sys
 import traceback
+import shutil
 
 import numpy as np
 from PIL import Image
@@ -248,7 +249,32 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+def create_config(ckpt_result, config_source, a, b, c):
+    def config(x):
+        return sd_models.find_checkpoint_config(x) if x else None
+
+    if config_source == 0:
+        cfg = config(a) or config(b) or config(c)
+    elif config_source == 1:
+        cfg = config(b)
+    elif config_source == 2:
+        cfg = config(c)
+    else:
+        cfg = None
+
+    if cfg is None:
+        return
+
+    filename, _ = os.path.splitext(ckpt_result)
+    checkpoint_filename = filename + ".yaml"
+
+    print("Copying config:")
+    print("   from:", cfg)
+    print("     to:", checkpoint_filename)
+    shutil.copyfile(cfg, checkpoint_filename)
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
     shared.state.begin()
     shared.state.job = 'model-merge'
 
@@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 
     sd_models.list_models()
 
+    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+
     print("Checkpoint saved.")
     shared.state.textinfo = "Checkpoint saved to " + output_modelname
     shared.state.end()
diff --git a/modules/ui.py b/modules/ui.py
index 3c458ce8..82f5dd7c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1129,7 +1129,7 @@ def create_ui():
             with gr.Column(variant='panel'):
                 gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
 
-                with gr.Row():
+                with FormRow():
                     primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
                     create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
 
@@ -1143,11 +1143,13 @@ def create_ui():
                 interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
                 interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
 
-                with gr.Row():
+                with FormRow():
                     checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                     save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
 
-                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
+                config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
 
             with gr.Column(variant='panel'):
                 submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
@@ -1703,6 +1705,7 @@ def create_ui():
                 save_as_half,
                 custom_name,
                 checkpoint_format,
+                config_source,
             ],
             outputs=[
                 submit_result,