faed465a0b
Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
225 lines
8.6 KiB
Python
225 lines
8.6 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from basicsr.utils.download_util import load_file_from_url
|
|
|
|
import modules.esrgan_model_arch as arch
|
|
from modules import shared, modelloader, images, devices
|
|
from modules.upscaler import Upscaler, UpscalerData
|
|
from modules.shared import opts
|
|
|
|
|
|
|
|
def mod2normal(state_dict):
    """Convert an old-architecture ESRGAN state dict to the ``model.N.*`` key layout.

    If ``state_dict`` contains ``'conv_first.weight'`` it is treated as an
    old-style checkpoint. A new dict is built in which:

    * ``conv_first`` maps to ``model.0``;
    * every key containing ``'RDB'`` has ``'RRDB_trunk.'`` renamed to
      ``'model.1.sub.'`` and ``.weight``/``.bias`` renamed to
      ``.0.weight``/``.0.bias``;
    * ``trunk_conv``, ``upconv1``, ``upconv2``, ``HRconv`` and ``conv_last``
      map to ``model.1.sub.23``, ``model.3``, ``model.6``, ``model.8`` and
      ``model.10`` respectively.

    Keys matching none of these rules are dropped. Otherwise ``state_dict``
    is returned unchanged (the same object).

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        The converted dict, or the original ``state_dict``.

    Raises:
        KeyError: if the checkpoint looks old-style but lacks one of the
            fixed layer keys listed above.
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' not in state_dict:
        return state_dict

    crt_net = {
        'model.0.weight': state_dict['conv_first.weight'],
        'model.0.bias': state_dict['conv_first.bias'],
    }

    # Single ordered pass; the original copied the key list and removed
    # entries from it one by one (O(n^2)) without ever reading it again.
    for key, value in state_dict.items():
        if 'RDB' not in key:
            continue
        new_key = key.replace('RRDB_trunk.', 'model.1.sub.')
        if '.weight' in key:
            new_key = new_key.replace('.weight', '.0.weight')
        elif '.bias' in key:
            new_key = new_key.replace('.bias', '.0.bias')
        crt_net[new_key] = value

    fixed_layers = (
        ('model.1.sub.23.weight', 'trunk_conv.weight'),
        ('model.1.sub.23.bias', 'trunk_conv.bias'),
        ('model.3.weight', 'upconv1.weight'),
        ('model.3.bias', 'upconv1.bias'),
        ('model.6.weight', 'upconv2.weight'),
        ('model.6.bias', 'upconv2.bias'),
        ('model.8.weight', 'HRconv.weight'),
        ('model.8.bias', 'HRconv.bias'),
        ('model.10.weight', 'conv_last.weight'),
        ('model.10.bias', 'conv_last.bias'),
    )
    for new_key, old_key in fixed_layers:
        crt_net[new_key] = state_dict[old_key]

    return crt_net
|
|
|
|
|
|
def resrgan2normal(state_dict, nb=23):
    """Convert a Real-ESRGAN (``body.N.rdbM``) state dict to the ``model.N.*`` layout.

    Conversion only happens when both ``'conv_first.weight'`` and
    ``'body.0.rdb1.conv1.weight'`` are present. A new dict is then built in which:

    * ``conv_first`` maps to ``model.0``;
    * every key containing ``'rdb'`` has ``'body.'`` renamed to
      ``'model.1.sub.'``, ``'.rdb'`` renamed to ``'.RDB'``, and
      ``.weight``/``.bias`` renamed to ``.0.weight``/``.0.bias``;
    * ``conv_body`` maps to ``model.1.sub.{nb}``;
    * ``conv_up1``, ``conv_up2``, ``conv_hr`` and ``conv_last`` map to
      ``model.3``, ``model.6``, ``model.8`` and ``model.10``.

    Keys matching none of these rules are dropped. Otherwise ``state_dict``
    is returned unchanged (the same object).

    Args:
        state_dict: Mapping of parameter names to values.
        nb: Index used for the ``conv_body`` entry under ``model.1.sub``.
            The caller in this file passes the number of RRDB blocks.

    Returns:
        The converted dict, or the original ``state_dict``.

    Raises:
        KeyError: if a required fixed layer key is missing.
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    if "conv_first.weight" not in state_dict or "body.0.rdb1.conv1.weight" not in state_dict:
        return state_dict

    crt_net = {
        'model.0.weight': state_dict['conv_first.weight'],
        'model.0.bias': state_dict['conv_first.bias'],
    }

    # Single ordered pass; the original copied the key list and removed
    # entries from it one by one (O(n^2)) without ever reading it again.
    for key, value in state_dict.items():
        if "rdb" not in key:
            continue
        new_key = key.replace('body.', 'model.1.sub.').replace('.rdb', '.RDB')
        if '.weight' in key:
            new_key = new_key.replace('.weight', '.0.weight')
        elif '.bias' in key:
            new_key = new_key.replace('.bias', '.0.bias')
        crt_net[new_key] = value

    fixed_layers = (
        (f'model.1.sub.{nb}.weight', 'conv_body.weight'),
        (f'model.1.sub.{nb}.bias', 'conv_body.bias'),
        ('model.3.weight', 'conv_up1.weight'),
        ('model.3.bias', 'conv_up1.bias'),
        ('model.6.weight', 'conv_up2.weight'),
        ('model.6.bias', 'conv_up2.bias'),
        ('model.8.weight', 'conv_hr.weight'),
        ('model.8.bias', 'conv_hr.bias'),
        ('model.10.weight', 'conv_last.weight'),
        ('model.10.bias', 'conv_last.bias'),
    )
    for new_key, old_key in fixed_layers:
        crt_net[new_key] = state_dict[old_key]

    return crt_net
|
|
|
|
|
|
def infer_params(state_dict):
    """Infer RRDBNet constructor arguments from a ``model.N.*``-style state dict.

    Inspects the key names (and a few tensor shapes):

    * ``nb`` comes from the last 5-part key of the form ``X.Y.sub.N.Z``
      (the trunk conv ``model.1.sub.N.*``);
    * ``out_nc`` is ``shape[0]`` of the first 3-part key with the highest
      layer index ``N`` above 0;
    * ``scale`` is ``2 ** k``, where ``k`` counts ``model.N.weight`` keys
      with ``N > 6``;
    * ``plus`` is True if any key contains ``'conv1x1'``;
    * ``nf``/``in_nc`` are ``shape[0]``/``shape[1]`` of ``model.0.weight``.

    Args:
        state_dict: Mapping of ``model.N...`` keys to tensors, e.g. the
            output of ``mod2normal``/``resrgan2normal``.

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, plus, scale)``.

    Raises:
        ValueError: if no trunk (``*.sub.N.*``) key or no output layer key
            is found, or if a 3-part key has a non-integer layer index.
        KeyError: if ``model.0.weight`` is missing.
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    plus = False
    # Previously unbound until a matching key was seen, which turned an
    # unexpected state dict into an opaque UnboundLocalError.
    nb = None
    out_nc = None

    for block in state_dict:
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if (part_num > scalemin
                    and parts[0] == "model"
                    and parts[2] == "weight"):
                scale2x += 1
            # The highest-indexed layer is taken to be the final conv.
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    if nb is None:
        raise ValueError("Could not infer the number of RRDB blocks: no 'model.1.sub.N.*' trunk key found.")
    if out_nc is None:
        raise ValueError("Could not infer the output channel count: no 'model.N.*' layer key with N > 0 found.")

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2 ** scale2x

    return in_nc, out_nc, nf, nb, plus, scale
|
|
|
|
|
|
class UpscalerESRGAN(Upscaler):
    """Upscaler backed by ESRGAN-family checkpoints (ESRGAN, Real-ESRGAN, SRVGG compact)."""

    def __init__(self, dirname):
        """Register one ``UpscalerData`` per model found by ``find_models``.

        Args:
            dirname: User-supplied model directory, stored as ``self.user_path``.
        """
        self.name = "ESRGAN"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        if not model_paths:
            # Fall back to the downloadable default model. This used to be
            # appended to a throwaway local list, so it was never registered.
            self.scalers.append(UpscalerData(self.model_name, self.model_url, self, 4))
        for file in model_paths:
            # URL entries are presented under the default model name.
            name = self.model_name if "http" in file else modelloader.friendly_name(file)
            self.scalers.append(UpscalerData(name, file, self, 4))

    def do_upscale(self, img, selected_model):
        """Upscale ``img`` with the model at ``selected_model``.

        Returns ``img`` unchanged if the model could not be loaded.
        """
        model = self.load_model(selected_model)
        if model is None:
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        """Load and return an eval-mode model from ``path``.

        If ``path`` contains ``"http"``, ``self.model_url`` is downloaded into
        ``self.model_path`` as ``<model_name>.pth`` and that file is loaded.

        Returns:
            An ``arch.SRVGGNetCompact`` for checkpoints with a ``"params"``
            entry, otherwise an ``arch.RRDBNet``; ``None`` if the file is
            missing.

        Raises:
            Exception: if the state dict is not a recognized ESRGAN layout.
        """
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.model_name,
                                          progress=True)
        else:
            filename = path
        # Check for None first: os.path.exists(None) raises TypeError.
        if filename is None or not os.path.exists(filename):
            print("Unable to load %s from %s" % (self.model_path, filename))
            return None

        # When targeting MPS, load onto the CPU first; otherwise use torch.load's default placement.
        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

        if "params_ema" in state_dict:
            state_dict = state_dict["params_ema"]
        elif "params" in state_dict:
            state_dict = state_dict["params"]
            # Compact (SRVGG) checkpoint; the anime video v3 variant uses 16 convs.
            num_conv = 16 if "realesr-animevideov3" in filename else 32
            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
            model.load_state_dict(state_dict)
            model.eval()
            return model

        # Normalize the key layout to "model.N.*" before inferring architecture parameters.
        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
            state_dict = resrgan2normal(state_dict, nb)
        elif "conv_first.weight" in state_dict:
            state_dict = mod2normal(state_dict)
        elif "model.0.weight" not in state_dict:
            raise Exception("The file is not a recognized ESRGAN model.")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)

        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
        model.load_state_dict(state_dict)
        model.eval()

        return model
|
|
|
|
|
|
def upscale_without_tiling(model, img):
    """Run ``model`` over the whole of ``img`` in a single forward pass.

    The channel axis is reversed (presumably RGB -> BGR for a 3-channel RGB
    input; TODO confirm callers always pass RGB), the image is laid out as
    CHW, scaled by 1/255, and given a leading batch dimension. The tensor is
    moved to ``devices.device_esrgan`` via ``devices.mps_contiguous_to``.
    The output is clamped to [0, 1], rescaled to 0..255 (truncated to uint8),
    channel-reversed again and returned as an ``'RGB'`` PIL image.
    """
    pixels = np.array(img)
    chw = np.ascontiguousarray(np.transpose(pixels[:, :, ::-1], (2, 0, 1))) / 255
    batch = torch.from_numpy(chw).float().unsqueeze(0)
    batch = devices.mps_contiguous_to(batch, devices.device_esrgan)

    with torch.no_grad():
        result = model(batch)

    hwc = np.moveaxis(result.squeeze().float().cpu().clamp_(0, 1).numpy(), 0, 2)
    upscaled = (255. * hwc).astype(np.uint8)
    return Image.fromarray(upscaled[:, :, ::-1], 'RGB')
|
|
|
|
|
|
def esrgan_upscale(model, img):
    """Upscale ``img`` with ``model``, tiling as configured in ``opts``.

    With ``opts.ESRGAN_tile == 0`` the whole image goes through the model at
    once. Otherwise the image is split into tiles of that size with
    ``opts.ESRGAN_tile_overlap`` overlap, and each tile is upscaled separately.
    Tile coordinates and grid dimensions are then multiplied by the observed
    scale factor (output width // tile width) before recombining.
    """
    tile_size = opts.ESRGAN_tile
    if tile_size == 0:
        return upscale_without_tiling(model, img)

    grid = images.split_grid(img, tile_size, tile_size, opts.ESRGAN_tile_overlap)
    scale_factor = 1
    upscaled_rows = []

    for y, h, row in grid.tiles:
        upscaled_row = []
        for x, w, tile in row:
            upscaled_tile = upscale_without_tiling(model, tile)
            # Derived per tile; the last value observed also scales the grid below.
            scale_factor = upscaled_tile.width // tile.width
            upscaled_row.append([x * scale_factor, w * scale_factor, upscaled_tile])
        upscaled_rows.append([y * scale_factor, h * scale_factor, upscaled_row])

    scaled_grid = images.Grid(
        upscaled_rows,
        grid.tile_w * scale_factor,
        grid.tile_h * scale_factor,
        grid.image_w * scale_factor,
        grid.image_h * scale_factor,
        grid.overlap * scale_factor,
    )
    return images.combine_grid(scaled_grid)
|