diff --git a/modules/extras.py b/modules/extras.py index 79047f3a..cffe0381 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ from PIL import Image import torch import tqdm -from typing import Callable, List, Tuple +from typing import Callable, Dict, List, Tuple from functools import partial from dataclasses import dataclass @@ -21,7 +21,18 @@ import piexif.helper import gradio as gr -cached_images = {} +@dataclass(frozen=True) +class CacheKey: + image_hash: int + info_hash: int + args_hash: int + +@dataclass +class CacheEntry: + image: Image.Image + info: str + +cached_images: Dict[CacheKey, CacheEntry] = {} def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): @@ -84,22 +95,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) - pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, - resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels - - c = cached_images.get(key) - if c is None: - upscaler = shared.sd_upscalers[scaler_index] - c = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2)) - c = cropped - cached_images[key] = c - return c - + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text @@ -118,8 +120,18 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: blended_result: Image.Image = None for upscaler in params: - res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()), + info_hash = hash(info), + args_hash = hash(upscale_args) ) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images[cache_key] = CacheEntry(image=res, info=info) + else: + res, info = cached_entry.image, cached_entry.info + if blended_result is None: blended_result = res else: