From 740070ea9cdb254209f66417418f2a4af8b099d6 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 26 Sep 2022 09:29:50 -0500 Subject: [PATCH] Re-implement universal model loading --- modules/codeformer_model.py | 35 +++++--- modules/esrgan_model.py | 56 +++++++++---- modules/extras.py | 2 + modules/gfpgan_model.py | 62 +++++++-------- modules/gfpgan_model_arch.py | 150 +++++++++++++++++++++++++++++++++++ modules/ldsr_model.py | 45 ++++++----- modules/modelloader.py | 65 +++++++++++++++ modules/paths.py | 3 +- modules/realesrgan_model.py | 23 +++++- modules/shared.py | 12 ++- modules/swinir_model.py | 83 +++++++++++++------ webui.py | 47 +++++------ 12 files changed, 449 insertions(+), 134 deletions(-) create mode 100644 modules/gfpgan_model_arch.py create mode 100644 modules/modelloader.py diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 8fbdea24..dc0a5eee 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,22 +5,28 @@ import traceback import cv2 import torch -from modules import shared, devices -from modules.paths import script_path +from modules import shared, devices, modelloader +from modules.paths import script_path, models_path import modules.shared import modules.face_restoration from importlib import reload -# codeformer people made a choice to include modified basicsr librry to their projectwhich makes -# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN. +# codeformer people made a choice to include modified basicsr library to their project, which makes +# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN. # I am making a choice to include some files from codeformer to work around this issue. - -pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' +model_dir = "Codeformer" +model_path = os.path.join(models_path, model_dir) +model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' have_codeformer = False codeformer = None -def setup_codeformer(): + +def setup_model(dirname): + global model_path + if not os.path.exists(model_path): + os.makedirs(model_path) + path = modules.paths.paths.get("CodeFormer", None) if path is None: return @@ -44,16 +50,22 @@ def setup_codeformer(): def name(self): return "CodeFormer" - def __init__(self): + def __init__(self, dirname): self.net = None self.face_helper = None + self.cmd_dir = dirname def create_models(self): if self.net is not None and self.face_helper is not None: self.net.to(devices.device_codeformer) return self.net, self.face_helper - + model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir) + if len(model_paths) != 0: + ckpt_path = model_paths[0] + else: + print("Unable to load codeformer model.") + return None, None net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True) checkpoint = torch.load(ckpt_path)['params_ema'] @@ -74,6 +86,9 @@ def setup_codeformer(): original_resolution = np_image.shape[0:2] self.create_models() + if self.net is None or self.face_helper is None: + return np_image + self.face_helper.clean_all() self.face_helper.read_image(np_image) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) @@ -114,7 +129,7 @@ def setup_codeformer(): have_codeformer = True global codeformer - codeformer = FaceRestorerCodeFormer() + codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) except Exception: diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 7f3baf31..dd0ee629 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,15 +5,35 @@ import traceback import numpy as np import torch from PIL import Image +from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch -from modules import shared -from modules.shared import opts -from modules.devices import has_mps import modules.images +from modules import shared +from modules import shared, modelloader +from modules.devices import has_mps +from modules.paths import models_path +from modules.shared import opts + +model_dir = "ESRGAN" +model_path = os.path.join(models_path, model_dir) +model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" +model_name = "ESRGAN_x4.pth" -def load_model(filename): +def load_model(path: str, name: str): + global model_path + global model_url + global model_dir + global model_name + if "http" in path: + filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True) + else: + filename = path + if not os.path.exists(filename) or filename is None: + print("Unable to load %s from %s" % (model_dir, filename)) + return None + print("Loading %s from %s" % (model_dir, filename)) # this code is adapted from https://github.com/xinntao/ESRGAN pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) @@ -118,24 +138,30 @@ def esrgan_upscale(model, img): class UpscalerESRGAN(modules.images.Upscaler): def __init__(self, filename, title): self.name = title - self.model = load_model(filename) + self.filename = filename def do_upscale(self, img): - model = self.model.to(shared.device) + model = load_model(self.filename, self.name) + if model is None: + return img + model.to(shared.device) img = esrgan_upscale(model, img) return img -def load_models(dirname): - for file in os.listdir(dirname): - path = os.path.join(dirname, file) - model_name, extension = os.path.splitext(file) - - if extension != '.pt' and extension != '.pth': - continue +def setup_model(dirname): + global model_path + global model_name + if not os.path.exists(model_path): + os.makedirs(model_path) + model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"]) + if len(model_paths) == 0: + modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name)) + for file in model_paths: + name = modelloader.friendly_name(file) try: - modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name)) + modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name)) except Exception: - print(f"Error loading ESRGAN model: {path}", file=sys.stderr) + print(f"Error loading ESRGAN model: {file}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/extras.py b/modules/extras.py index 382ffa7d..4c95cf76 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -36,6 +36,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs = [] for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' existing_pnginfo = image.info or {} image = image.convert("RGB") diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 44c5dc6c..ffb6960d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -7,33 +7,20 @@ from modules import shared, devices from modules.shared import cmd_opts from modules.paths import script_path import modules.face_restoration +from modules import shared, devices, modelloader +from modules.paths import models_path - -def gfpgan_model_path(): - from modules.shared import cmd_opts - - filemask = 'GFPGAN*.pth' - - if cmd_opts.gfpgan_model is not None: - return cmd_opts.gfpgan_model - - places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')] - - filename = None - for place in places: - filename = next(iter(glob(os.path.join(place, filemask))), None) - if filename is not None: - break - - return filename - +model_dir = "GFPGAN" +cmd_dir = None +model_path = os.path.join(models_path, model_dir) +model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" loaded_gfpgan_model = None def gfpgan(): global loaded_gfpgan_model - + global model_path if loaded_gfpgan_model is not None: loaded_gfpgan_model.gfpgan.to(shared.device) return loaded_gfpgan_model @@ -41,7 +28,15 @@ def gfpgan(): if gfpgan_constructor is None: return None - model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) + models = modelloader.load_models(model_path, model_url, cmd_dir) + if len(models) != 0: + latest_file = max(models, key=os.path.getctime) + model_file = latest_file + else: + print("Unable to load gfpgan model!") + return None + model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2, + bg_upsampler=None) model.gfpgan.to(shared.device) loaded_gfpgan_model = model @@ -50,7 +45,8 @@ def gfpgan(): def gfpgan_fix_faces(np_image): model = gfpgan() - + if model is None: + return np_image np_image_bgr = np_image[:, :, ::-1] cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] @@ -64,19 +60,21 @@ def gfpgan_fix_faces(np_image): have_gfpgan = False gfpgan_constructor = None -def setup_gfpgan(): + +def setup_model(dirname): + global model_path + if not os.path.exists(model_path): + os.makedirs(model_path) + try: - gfpgan_model_path() - - if os.path.exists(cmd_opts.gfpgan_dir): - sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir)) - from gfpgan import GFPGANer - + from modules.gfpgan_model_arch import GFPGANerr + global cmd_dir global have_gfpgan - have_gfpgan = True - global gfpgan_constructor - gfpgan_constructor = GFPGANer + + cmd_dir = dirname + have_gfpgan = True + gfpgan_constructor = GFPGANerr class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): diff --git a/modules/gfpgan_model_arch.py b/modules/gfpgan_model_arch.py new file mode 100644 index 00000000..d81cea96 --- /dev/null +++ b/modules/gfpgan_model_arch.py @@ -0,0 +1,150 @@ +# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original... + +import cv2 +import os +import torch +from basicsr.utils import img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize + +from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class GFPGANerr(): + """Helper for restoration with GFPGAN. + + It will detect and crop faces, and then resize the faces to 512x512. + GFPGAN is used to restored the resized faces. + The background is upsampled with the bg_upsampler. + Finally, the faces will be pasted back to the upsample background image. + + Args: + model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). + upscale (float): The upscale of the final output. Default: 2. + arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + bg_upsampler (nn.Module): The upsampler for the background. Default: None. + """ + + def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + # initialize the GFP-GAN + if arch == 'clean': + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'bilinear': + self.gfpgan = GFPGANBilinear( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'original': + self.gfpgan = GFPGANv1( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'RestoreFormer': + from gfpgan.archs.restoreformer_arch import RestoreFormer + self.gfpgan = RestoreFormer() + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=self.device, + model_rootpath=model_dir) + + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=model_dir, progress=True, file_name=None) + loadnet = torch.load(model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) + + @torch.no_grad() + def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5): + self.face_helper.clean_all() + + if has_aligned: # the inputs are already aligned + img = cv2.resize(img, (512, 512)) + self.face_helper.cropped_faces = [img] + else: + self.face_helper.read_image(img) + # get face landmarks for each face + self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) + # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels + # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. + # align and warp each face + self.face_helper.align_warp_face() + + # face restoration + for cropped_face in self.face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) + + try: + output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0] + # convert to image + restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) + except RuntimeError as error: + print(f'\tFailed inference for GFPGAN: {error}.') + restored_face = cropped_face + + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) + + if not has_aligned and paste_back: + # upsample the background + if self.bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] + else: + bg_img = None + + self.face_helper.get_inverse_affine(None) + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) + return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img + else: + return self.face_helper.cropped_faces, self.face_helper.restored_faces, None diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 95e84659..e6e7ff74 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -3,11 +3,14 @@ import sys import traceback from collections import namedtuple -from basicsr.utils.download_util import load_file_from_url +from modules import shared, images, modelloader, paths +from modules.paths import models_path -import modules.images -from modules import shared -from modules.paths import script_path +model_dir = "LDSR" +model_path = os.path.join(models_path, model_dir) +cmd_path = None +model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" +yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) @@ -25,28 +28,32 @@ class UpscalerLDSR(modules.images.Upscaler): return upscale_with_ldsr(img) -def add_lsdr(): - modules.shared.sd_upscalers.append(UpscalerLDSR(100)) +def setup_model(dirname): + global cmd_path + global model_path + if not os.path.exists(model_path): + os.makedirs(model_path) + cmd_path = dirname + shared.sd_upscalers.append(UpscalerLDSR(100)) -def setup_ldsr(): - path = modules.paths.paths.get("LDSR", None) +def prepare_ldsr(): + path = paths.paths.get("LDSR", None) if path is None: return global have_ldsr global LDSR_obj try: from LDSR import LDSR - model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" - yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" - repo_path = 'latent-diffusion/experiments/pretrained_models/' - model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path), - progress=True, file_name="model.chkpt") - yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path), - progress=True, file_name="project.yaml") - have_ldsr = True - LDSR_obj = LDSR(model_path, yaml_path) - + model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"]) + yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"]) + if len(model_files) != 0 and len(yaml_files) != 0: + model_file = model_files[0] + yaml_file = yaml_files[0] + have_ldsr = True + LDSR_obj = LDSR(model_file, yaml_file) + else: + return except Exception: print("Error importing LDSR:", file=sys.stderr) @@ -55,7 +62,7 @@ def setup_ldsr(): def upscale_with_ldsr(image): - setup_ldsr() + prepare_ldsr() if not have_ldsr or LDSR_obj is None: return image diff --git a/modules/modelloader.py b/modules/modelloader.py new file mode 100644 index 00000000..d59fbe05 --- /dev/null +++ b/modules/modelloader.py @@ -0,0 +1,65 @@ +import os +from urllib.parse import urlparse + +from basicsr.utils.download_util import load_file_from_url + + +def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None, + ext_filter=None) -> list: + """ + A one-and done loader to try finding the desired models in specified directories. + + @param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL. + @param model_url: If specified, attempt to download model from the given URL. + @param model_path: The location to store/find models in. + @param command_path: A command-line argument to search for models in first. + @param existing: An array of existing model paths. + @param ext_filter: An optional list of filename extensions to filter by + @return: A list of paths containing the desired model(s) + """ + if ext_filter is None: + ext_filter = [] + if existing is None: + existing = [] + try: + places = [] + if command_path is not None and command_path != model_path: + pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') + if os.path.exists(pretrained_path): + places.append(pretrained_path) + elif os.path.exists(command_path): + places.append(command_path) + places.append(model_path) + for place in places: + if os.path.exists(place): + for file in os.listdir(place): + if os.path.isdir(file): + continue + if len(ext_filter) != 0: + model_name, extension = os.path.splitext(file) + if extension not in ext_filter: + continue + if file not in existing: + path = os.path.join(place, file) + existing.append(path) + if model_url is not None: + if dl_name is not None: + model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True) + else: + model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True) + + if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing: + existing.append(model_file) + except: + pass + return existing + + +def friendly_name(file: str): + if "http" in file: + file = urlparse(file).path + + file = os.path.basename(file) + model_name, extension = os.path.splitext(file) + model_name = model_name.replace("_", " ").title() + return model_name diff --git a/modules/paths.py b/modules/paths.py index 3a19f9e5..015fa672 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -3,9 +3,10 @@ import os import sys script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +models_path = os.path.join(script_path, "models") sys.path.insert(0, script_path) -# search for directory of stable diffsuion in following palces +# search for directory of stable diffusion in following places sd_path = None possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] for possible_sd_path in possible_sd_paths: diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c32d6c4c..458bf678 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,14 +1,20 @@ +import os import sys import traceback from collections import namedtuple import numpy as np from PIL import Image +from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer import modules.images +from modules.paths import models_path from modules.shared import cmd_opts, opts +model_dir = "RealESRGAN" +model_path = os.path.join(models_path, model_dir) +cmd_dir = None RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) realesrgan_models = [] have_realesrgan = False @@ -17,7 +23,6 @@ have_realesrgan = False def get_realesrgan_models(): try: from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer from realesrgan.archs.srvgg_arch import SRVGGNetCompact models = [ RealesrganModelInfo( @@ -59,7 +64,7 @@ def get_realesrgan_models(): ] return models except Exception as e: - print("Error makeing Real-ESRGAN midels list:", file=sys.stderr) + print("Error making Real-ESRGAN models list:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler): return upscale_with_realesrgan(img, self.upscaling, self.model_index) -def setup_realesrgan(): +def setup_model(dirname): + global model_path + if not os.path.exists(model_path): + os.makedirs(model_path) + global realesrgan_models global have_realesrgan - + if model_path != dirname: + model_path = dirname try: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer @@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) info = realesrgan_models[RealESRGAN_model_index] model = info.model() + model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True) + if not os.path.exists(model_file): + print("Unable to load RealESRGAN model: %s" % info.name) + return image + upsampler = RealESRGANer( scale=info.netscale, model_path=info.location, diff --git a/modules/shared.py b/modules/shared.py index c32da110..1444040d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,11 +16,11 @@ import modules.sd_models sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file - +model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) -parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",) +parser.add_argument("--ckpt-dir", type=str, default=model_path, help="path to directory with stable diffusion checkpoints",) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") @@ -34,8 +34,12 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") -parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) -parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR')) +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model(s)", default=os.path.join(model_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model(s)", default=os.path.join(model_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN models", default=os.path.join(model_path, 'ESRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN models", default=os.path.join(model_path, 'RealESRGAN')) +parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR models", default=os.path.join(model_path, 'SwinIR')) +parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR models", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") diff --git a/modules/swinir_model.py b/modules/swinir_model.py index e86d0789..f515779e 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -1,21 +1,39 @@ +import contextlib +import os import sys import traceback -import cv2 -import os -import contextlib -import numpy as np -from PIL import Image -import torch -import modules.images -from modules.shared import cmd_opts, opts, device -from modules.swinir_arch import SwinIR as net +import numpy as np +import torch +from PIL import Image +from basicsr.utils.download_util import load_file_from_url + +import modules.images +from modules import modelloader +from modules.paths import models_path +from modules.shared import cmd_opts, opts, device +from modules.swinir_model_arch import SwinIR as net + +model_dir = "SwinIR" +model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" +model_name = "SwinIR x4" +model_path = os.path.join(models_path, model_dir) +cmd_path = "" precision_scope = ( torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext ) -def load_model(filename, scale=4): +def load_model(path, scale=4): + global model_path + global model_name + if "http" in path: + dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None model = net( upscale=scale, in_chans=3, @@ -37,19 +55,29 @@ def load_model(filename, scale=4): return model -def load_models(dirname): - for file in os.listdir(dirname): - path = os.path.join(dirname, file) - model_name, extension = os.path.splitext(file) +def setup_model(dirname): + global model_path + global model_name + global cmd_path + if not os.path.exists(model_path): + os.makedirs(model_path) + cmd_path = dirname + model_file = "" + try: + models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path) - if extension != ".pt" and extension != ".pth": - continue + if len(models) != 0: + model_file = models[0] + name = modelloader.friendly_name(model_file) + else: + # Add the "default" model if none are found. + model_file = model_url + name = model_name - try: - modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name)) - except Exception: - print(f"Error loading SwinIR model: {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name)) + except Exception: + print(f"Error loading SwinIR model: {model_file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) def upscale( @@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale): class UpscalerSwin(modules.images.Upscaler): def __init__(self, filename, title): self.name = title - self.model = load_model(filename) + self.filename = filename def do_upscale(self, img): - model = self.model.to(device) + model = load_model(self.filename) + if model is None: + return img + model = model.to(device) img = upscale(img, model) - return img + try: + torch.cuda.empty_cache() + except: + pass + return img \ No newline at end of file diff --git a/webui.py b/webui.py index 9ea5f5a3..7e0b3296 100644 --- a/webui.py +++ b/webui.py @@ -1,37 +1,34 @@ import os +import signal import threading -from modules.paths import script_path - -import signal - -from modules.shared import opts, cmd_opts, state -import modules.shared as shared -import modules.ui +import modules.codeformer_model as codeformer +import modules.esrgan_model as esrgan +import modules.extras +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img +import modules.ldsr_model as ldsr +import modules.lowvram +import modules.realesrgan_model as realesrgan import modules.scripts import modules.sd_hijack -import modules.codeformer_model -import modules.gfpgan_model -import modules.face_restoration -import modules.realesrgan_model as realesrgan -import modules.esrgan_model as esrgan -import modules.ldsr_model as ldsr -import modules.extras -import modules.lowvram -import modules.txt2img -import modules.img2img -import modules.swinir as swinir import modules.sd_models +import modules.shared as shared +import modules.swinir_model as swinir +import modules.txt2img +import modules.ui +from modules.paths import script_path +from modules.shared import cmd_opts - -modules.codeformer_model.setup_codeformer() -modules.gfpgan_model.setup_gfpgan() +codeformer.setup_model(cmd_opts.codeformer_models_path) +gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) -esrgan.load_models(cmd_opts.esrgan_models_path) -swinir.load_models(cmd_opts.swinir_models_path) -realesrgan.setup_realesrgan() -ldsr.add_lsdr() +esrgan.setup_model(cmd_opts.esrgan_models_path) +swinir.setup_model(cmd_opts.swinir_models_path) +realesrgan.setup_model(cmd_opts.realesrgan_models_path) +ldsr.setup_model(cmd_opts.ldsr_models_path) queue_lock = threading.Lock()