Re-implement universal model loading

This commit is contained in:
d8ahazard 2022-09-26 09:29:50 -05:00
parent bfb7f15d46
commit 740070ea9c
12 changed files with 449 additions and 134 deletions

View File

@ -5,22 +5,28 @@ import traceback
import cv2 import cv2
import torch import torch
from modules import shared, devices from modules import shared, devices, modelloader
from modules.paths import script_path from modules.paths import script_path, models_path
import modules.shared import modules.shared
import modules.face_restoration import modules.face_restoration
from importlib import reload from importlib import reload
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes # codeformer people made a choice to include modified basicsr library to their project, which makes
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN. # it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue. # I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False have_codeformer = False
codeformer = None codeformer = None
def setup_codeformer():
def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
path = modules.paths.paths.get("CodeFormer", None) path = modules.paths.paths.get("CodeFormer", None)
if path is None: if path is None:
return return
@ -44,16 +50,22 @@ def setup_codeformer():
def name(self): def name(self):
return "CodeFormer" return "CodeFormer"
def __init__(self): def __init__(self, dirname):
self.net = None self.net = None
self.face_helper = None self.face_helper = None
self.cmd_dir = dirname
def create_models(self): def create_models(self):
if self.net is not None and self.face_helper is not None: if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer) self.net.to(devices.device_codeformer)
return self.net, self.face_helper return self.net, self.face_helper
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
print("Unable to load codeformer model.")
return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True) ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema'] checkpoint = torch.load(ckpt_path)['params_ema']
@ -74,6 +86,9 @@ def setup_codeformer():
original_resolution = np_image.shape[0:2] original_resolution = np_image.shape[0:2]
self.create_models() self.create_models()
if self.net is None or self.face_helper is None:
return np_image
self.face_helper.clean_all() self.face_helper.clean_all()
self.face_helper.read_image(np_image) self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@ -114,7 +129,7 @@ def setup_codeformer():
have_codeformer = True have_codeformer = True
global codeformer global codeformer
codeformer = FaceRestorerCodeFormer() codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer) shared.face_restorers.append(codeformer)
except Exception: except Exception:

View File

@ -5,15 +5,35 @@ import traceback
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch import modules.esrgam_model_arch as arch
from modules import shared
from modules.shared import opts
from modules.devices import has_mps
import modules.images import modules.images
from modules import shared
from modules import shared, modelloader
from modules.devices import has_mps
from modules.paths import models_path
from modules.shared import opts
model_dir = "ESRGAN"
model_path = os.path.join(models_path, model_dir)
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
model_name = "ESRGAN_x4.pth"
def load_model(filename): def load_model(path: str, name: str):
global model_path
global model_url
global model_dir
global model_name
if "http" in path:
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (model_dir, filename))
return None
print("Loading %s from %s" % (model_dir, filename))
# this code is adapted from https://github.com/xinntao/ESRGAN # this code is adapted from https://github.com/xinntao/ESRGAN
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
@ -118,24 +138,30 @@ def esrgan_upscale(model, img):
class UpscalerESRGAN(modules.images.Upscaler): class UpscalerESRGAN(modules.images.Upscaler):
def __init__(self, filename, title): def __init__(self, filename, title):
self.name = title self.name = title
self.model = load_model(filename) self.filename = filename
def do_upscale(self, img): def do_upscale(self, img):
model = self.model.to(shared.device) model = load_model(self.filename, self.name)
if model is None:
return img
model.to(shared.device)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
def load_models(dirname): def setup_model(dirname):
for file in os.listdir(dirname): global model_path
path = os.path.join(dirname, file) global model_name
model_name, extension = os.path.splitext(file) if not os.path.exists(model_path):
os.makedirs(model_path)
if extension != '.pt' and extension != '.pth':
continue
model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
if len(model_paths) == 0:
modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
for file in model_paths:
name = modelloader.friendly_name(file)
try: try:
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name)) modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
except Exception: except Exception:
print(f"Error loading ESRGAN model: {path}", file=sys.stderr) print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)

View File

@ -36,6 +36,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
outputs = [] outputs = []
for image, image_name in zip(imageArr, imageNameArr): for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {} existing_pnginfo = image.info or {}
image = image.convert("RGB") image = image.convert("RGB")

View File

@ -7,33 +7,20 @@ from modules import shared, devices
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules.paths import script_path from modules.paths import script_path
import modules.face_restoration import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN"
def gfpgan_model_path(): cmd_dir = None
from modules.shared import cmd_opts model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
filemask = 'GFPGAN*.pth'
if cmd_opts.gfpgan_model is not None:
return cmd_opts.gfpgan_model
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
filename = None
for place in places:
filename = next(iter(glob(os.path.join(place, filemask))), None)
if filename is not None:
break
return filename
loaded_gfpgan_model = None loaded_gfpgan_model = None
def gfpgan(): def gfpgan():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device) loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model return loaded_gfpgan_model
@ -41,7 +28,15 @@ def gfpgan():
if gfpgan_constructor is None: if gfpgan_constructor is None:
return None return None
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) models = modelloader.load_models(model_path, model_url, cmd_dir)
if len(models) != 0:
latest_file = max(models, key=os.path.getctime)
model_file = latest_file
else:
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
bg_upsampler=None)
model.gfpgan.to(shared.device) model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
@ -50,7 +45,8 @@ def gfpgan():
def gfpgan_fix_faces(np_image): def gfpgan_fix_faces(np_image):
model = gfpgan() model = gfpgan()
if model is None:
return np_image
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
@ -64,19 +60,21 @@ def gfpgan_fix_faces(np_image):
have_gfpgan = False have_gfpgan = False
gfpgan_constructor = None gfpgan_constructor = None
def setup_gfpgan():
def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
try: try:
gfpgan_model_path() from modules.gfpgan_model_arch import GFPGANerr
global cmd_dir
if os.path.exists(cmd_opts.gfpgan_dir):
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
from gfpgan import GFPGANer
global have_gfpgan global have_gfpgan
have_gfpgan = True
global gfpgan_constructor global gfpgan_constructor
gfpgan_constructor = GFPGANer
cmd_dir = dirname
have_gfpgan = True
gfpgan_constructor = GFPGANerr
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self): def name(self):

View File

@ -0,0 +1,150 @@
# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANerr():
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
GFPGAN is used to restored the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsample background image.
Args:
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
upscale (float): The upscale of the final output. Default: 2.
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# initialize the GFP-GAN
if arch == 'clean':
self.gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'bilinear':
self.gfpgan = GFPGANBilinear(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'original':
self.gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=True,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'RestoreFormer':
from gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device=self.device,
model_rootpath=model_dir)
if model_path.startswith('https://'):
model_path = load_file_from_url(
url=model_path, model_dir=model_dir, progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
self.face_helper.clean_all()
if has_aligned: # the inputs are already aligned
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
self.face_helper.read_image(img)
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
# align and warp each face
self.face_helper.align_warp_face()
# face restoration
for cropped_face in self.face_helper.cropped_faces:
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for GFPGAN: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
# upsample the background
if self.bg_upsampler is not None:
# Now only support RealESRGAN for upsampling background
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
self.face_helper.get_inverse_affine(None)
# paste each restored face to the input image
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None

View File

@ -3,11 +3,14 @@ import sys
import traceback import traceback
from collections import namedtuple from collections import namedtuple
from basicsr.utils.download_util import load_file_from_url from modules import shared, images, modelloader, paths
from modules.paths import models_path
import modules.images model_dir = "LDSR"
from modules import shared model_path = os.path.join(models_path, model_dir)
from modules.paths import script_path cmd_path = None
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
@ -25,28 +28,32 @@ class UpscalerLDSR(modules.images.Upscaler):
return upscale_with_ldsr(img) return upscale_with_ldsr(img)
def add_lsdr(): def setup_model(dirname):
modules.shared.sd_upscalers.append(UpscalerLDSR(100)) global cmd_path
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
cmd_path = dirname
shared.sd_upscalers.append(UpscalerLDSR(100))
def setup_ldsr(): def prepare_ldsr():
path = modules.paths.paths.get("LDSR", None) path = paths.paths.get("LDSR", None)
if path is None: if path is None:
return return
global have_ldsr global have_ldsr
global LDSR_obj global LDSR_obj
try: try:
from LDSR import LDSR from LDSR import LDSR
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
repo_path = 'latent-diffusion/experiments/pretrained_models/' if len(model_files) != 0 and len(yaml_files) != 0:
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path), model_file = model_files[0]
progress=True, file_name="model.chkpt") yaml_file = yaml_files[0]
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path), have_ldsr = True
progress=True, file_name="project.yaml") LDSR_obj = LDSR(model_file, yaml_file)
have_ldsr = True else:
LDSR_obj = LDSR(model_path, yaml_path) return
except Exception: except Exception:
print("Error importing LDSR:", file=sys.stderr) print("Error importing LDSR:", file=sys.stderr)
@ -55,7 +62,7 @@ def setup_ldsr():
def upscale_with_ldsr(image): def upscale_with_ldsr(image):
setup_ldsr() prepare_ldsr()
if not have_ldsr or LDSR_obj is None: if not have_ldsr or LDSR_obj is None:
return image return image

65
modules/modelloader.py Normal file
View File

@ -0,0 +1,65 @@
import os
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None,
ext_filter=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL.
@param model_url: If specified, attempt to download model from the given URL.
@param model_path: The location to store/find models in.
@param command_path: A command-line argument to search for models in first.
@param existing: An array of existing model paths.
@param ext_filter: An optional list of filename extensions to filter by
@return: A list of paths containing the desired model(s)
"""
if ext_filter is None:
ext_filter = []
if existing is None:
existing = []
try:
places = []
if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path):
places.append(pretrained_path)
elif os.path.exists(command_path):
places.append(command_path)
places.append(model_path)
for place in places:
if os.path.exists(place):
for file in os.listdir(place):
if os.path.isdir(file):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in existing:
path = os.path.join(place, file)
existing.append(path)
if model_url is not None:
if dl_name is not None:
model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True)
else:
model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True)
if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing:
existing.append(model_file)
except:
pass
return existing
def friendly_name(file: str):
if "http" in file:
file = urlparse(file).path
file = os.path.basename(file)
model_name, extension = os.path.splitext(file)
model_name = model_name.replace("_", " ").title()
return model_name

View File

@ -3,9 +3,10 @@ import os
import sys import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
sys.path.insert(0, script_path) sys.path.insert(0, script_path)
# search for directory of stable diffsuion in following palces # search for directory of stable diffusion in following places
sd_path = None sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths: for possible_sd_path in possible_sd_paths:

View File

@ -1,14 +1,20 @@
import os
import sys import sys
import traceback import traceback
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
import modules.images import modules.images
from modules.paths import models_path
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
model_dir = "RealESRGAN"
model_path = os.path.join(models_path, model_dir)
cmd_dir = None
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = [] realesrgan_models = []
have_realesrgan = False have_realesrgan = False
@ -17,7 +23,6 @@ have_realesrgan = False
def get_realesrgan_models(): def get_realesrgan_models():
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [ models = [
RealesrganModelInfo( RealesrganModelInfo(
@ -59,7 +64,7 @@ def get_realesrgan_models():
] ]
return models return models
except Exception as e: except Exception as e:
print("Error makeing Real-ESRGAN midels list:", file=sys.stderr) print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler):
return upscale_with_realesrgan(img, self.upscaling, self.model_index) return upscale_with_realesrgan(img, self.upscaling, self.model_index)
def setup_realesrgan(): def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
global realesrgan_models global realesrgan_models
global have_realesrgan global have_realesrgan
if model_path != dirname:
model_path = dirname
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
info = realesrgan_models[RealESRGAN_model_index] info = realesrgan_models[RealESRGAN_model_index]
model = info.model() model = info.model()
model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
if not os.path.exists(model_file):
print("Unable to load RealESRGAN model: %s" % info.name)
return image
upsampler = RealESRGANer( upsampler = RealESRGANer(
scale=info.netscale, scale=info.netscale,
model_path=info.location, model_path=info.location,

View File

@ -16,11 +16,11 @@ import modules.sd_models
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",) parser.add_argument("--ckpt-dir", type=str, default=model_path, help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@ -34,8 +34,12 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model(s)", default=os.path.join(model_path, 'Codeformer'))
parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model(s)", default=os.path.join(model_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN models", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN models", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR models", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR models", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")

View File

@ -1,21 +1,39 @@
import contextlib
import os
import sys import sys
import traceback import traceback
import cv2
import os
import contextlib
import numpy as np
from PIL import Image
import torch
import modules.images
from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.images
from modules import modelloader
from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
model_dir = "SwinIR"
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
model_name = "SwinIR x4"
model_path = os.path.join(models_path, model_dir)
cmd_path = ""
precision_scope = ( precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
) )
def load_model(filename, scale=4): def load_model(path, scale=4):
global model_path
global model_name
if "http" in path:
dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
else:
filename = path
if filename is None or not os.path.exists(filename):
return None
model = net( model = net(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
@ -37,19 +55,29 @@ def load_model(filename, scale=4):
return model return model
def load_models(dirname): def setup_model(dirname):
for file in os.listdir(dirname): global model_path
path = os.path.join(dirname, file) global model_name
model_name, extension = os.path.splitext(file) global cmd_path
if not os.path.exists(model_path):
os.makedirs(model_path)
cmd_path = dirname
model_file = ""
try:
models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
if extension != ".pt" and extension != ".pth": if len(models) != 0:
continue model_file = models[0]
name = modelloader.friendly_name(model_file)
else:
# Add the "default" model if none are found.
model_file = model_url
name = model_name
try: modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name)) except Exception:
except Exception: print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
print(f"Error loading SwinIR model: {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def upscale( def upscale(
@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
class UpscalerSwin(modules.images.Upscaler): class UpscalerSwin(modules.images.Upscaler):
def __init__(self, filename, title): def __init__(self, filename, title):
self.name = title self.name = title
self.model = load_model(filename) self.filename = filename
def do_upscale(self, img): def do_upscale(self, img):
model = self.model.to(device) model = load_model(self.filename)
if model is None:
return img
model = model.to(device)
img = upscale(img, model) img = upscale(img, model)
try:
torch.cuda.empty_cache()
except:
pass
return img return img

View File

@ -1,37 +1,34 @@
import os import os
import signal
import threading import threading
from modules.paths import script_path import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
import signal import modules.extras
import modules.face_restoration
from modules.shared import opts, cmd_opts, state import modules.gfpgan_model as gfpgan
import modules.shared as shared import modules.img2img
import modules.ui import modules.ldsr_model as ldsr
import modules.lowvram
import modules.realesrgan_model as realesrgan
import modules.scripts import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.codeformer_model
import modules.gfpgan_model
import modules.face_restoration
import modules.realesrgan_model as realesrgan
import modules.esrgan_model as esrgan
import modules.ldsr_model as ldsr
import modules.extras
import modules.lowvram
import modules.txt2img
import modules.img2img
import modules.swinir as swinir
import modules.sd_models import modules.sd_models
import modules.shared as shared
import modules.swinir_model as swinir
import modules.txt2img
import modules.ui
from modules.paths import script_path
from modules.shared import cmd_opts
codeformer.setup_model(cmd_opts.codeformer_models_path)
modules.codeformer_model.setup_codeformer() gfpgan.setup_model(cmd_opts.gfpgan_models_path)
modules.gfpgan_model.setup_gfpgan()
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
esrgan.load_models(cmd_opts.esrgan_models_path) esrgan.setup_model(cmd_opts.esrgan_models_path)
swinir.load_models(cmd_opts.swinir_models_path) swinir.setup_model(cmd_opts.swinir_models_path)
realesrgan.setup_realesrgan() realesrgan.setup_model(cmd_opts.realesrgan_models_path)
ldsr.add_lsdr() ldsr.setup_model(cmd_opts.ldsr_models_path)
queue_lock = threading.Lock() queue_lock = threading.Lock()