fix auto fill and repair separate axisoptions

This commit is contained in:
Vladimir Repin 2023-01-21 22:43:37 +03:00
parent f53527f778
commit ac2eb97db9

View File

@ -165,10 +165,14 @@ class AxisOption:
self.confirm = confirm self.confirm = confirm
self.cost = cost self.cost = cost
self.choices = choices self.choices = choices
self.is_img2img = False
class AxisOptionImg2Img(AxisOption): class AxisOptionImg2Img(AxisOption):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_img2img = True
class AxisOptionTxt2Img(AxisOption):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.is_img2img = False self.is_img2img = False
@ -183,7 +187,8 @@ axis_options = [
AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("CFG Scale", float, apply_field("cfg_scale")),
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma min", float, apply_field("s_tmin")),
@ -192,8 +197,8 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta")), AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")), AxisOption("Denoising", float, apply_field("denoising_strength")),
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]), AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
] ]
@ -288,7 +293,7 @@ class Script(scripts.Script):
return "X/Y plot" return "X/Y plot"
def ui(self, is_img2img): def ui(self, is_img2img):
current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img] current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
with gr.Row(): with gr.Row():
with gr.Column(scale=19): with gr.Column(scale=19):
@ -316,14 +321,14 @@ class Script(scripts.Script):
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
def fill(x_type): def fill(x_type):
axis = axis_options[x_type] axis = current_axis_options[x_type]
return ", ".join(axis.choices()) if axis.choices else gr.update() return ", ".join(axis.choices()) if axis.choices else gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
def select_axis(x_type): def select_axis(x_type):
return gr.Button.update(visible=axis_options[x_type].choices is not None) return gr.Button.update(visible=current_axis_options[x_type].choices is not None)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])