615 lines
21 KiB
Python
615 lines
21 KiB
Python
import os
|
|
import math
|
|
import pickle
|
|
import random
|
|
|
|
import numpy
|
|
import numpy as np
|
|
import glob
|
|
import torch
|
|
import torchvision
|
|
import cv2
|
|
|
|
####################
|
|
# Files & IO
|
|
####################
|
|
|
|
###################### get image path list ######################
|
|
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP']
|
|
|
|
|
|
def torch2cv(tensor):
|
|
if len(tensor.shape) == 4:
|
|
squeezed = True
|
|
tensor = tensor.squeeze(0)
|
|
assert len(tensor.shape) == 3
|
|
pil = torchvision.transforms.ToPILImage()(tensor)
|
|
np = numpy.array(pil)
|
|
return cv2.cvtColor(np, cv2.COLOR_RGB2BGR) / 255.0
|
|
|
|
|
|
def cv2torch(cv, batchify=True):
|
|
cv = cv2.cvtColor(cv, cv2.COLOR_BGR2RGB)
|
|
tens = torch.from_numpy(np.ascontiguousarray(np.transpose(cv, (2, 0, 1)))).float()
|
|
if batchify:
|
|
tens = tens.unsqueeze(0)
|
|
return tens
|
|
|
|
|
|
def is_image_file(filename):
|
|
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
|
|
|
|
|
def is_wav_file(filename):
|
|
return filename.endswith('.wav')
|
|
|
|
|
|
def is_audio_file(filename):
|
|
AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', 'm4b']
|
|
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
|
|
|
|
|
|
def _get_paths_from_images(path, qualifier=is_image_file):
|
|
"""get image path list from image folder"""
|
|
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
|
images = []
|
|
for dirpath, _, fnames in sorted(os.walk(path)):
|
|
for fname in sorted(fnames):
|
|
if qualifier(fname) and 'ref.jpg' not in fname:
|
|
img_path = os.path.join(dirpath, fname)
|
|
images.append(img_path)
|
|
if not images:
|
|
print("Warning: {:s} has no valid image file".format(path))
|
|
return images
|
|
|
|
|
|
def _get_paths_from_lmdb(dataroot):
|
|
"""get image path list from lmdb meta info"""
|
|
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
|
|
paths = meta_info['keys']
|
|
sizes = meta_info['resolution']
|
|
if len(sizes) == 1:
|
|
sizes = sizes * len(paths)
|
|
return paths, sizes
|
|
|
|
|
|
def find_audio_files(dataroot, include_nonwav=False):
|
|
if include_nonwav:
|
|
return find_files_of_type(None, dataroot, qualifier=is_audio_file)[0]
|
|
else:
|
|
return find_files_of_type(None, dataroot, qualifier=is_wav_file)[0]
|
|
|
|
|
|
def find_files_of_type(data_type, dataroot, weights=[], qualifier=is_image_file):
|
|
if isinstance(dataroot, list):
|
|
paths = []
|
|
for i in range(len(dataroot)):
|
|
r = dataroot[i]
|
|
extends = 1
|
|
|
|
# Weights have the effect of repeatedly adding the paths from the given root to the final product.
|
|
if weights:
|
|
extends = weights[i]
|
|
for j in range(extends):
|
|
paths.extend(_get_paths_from_images(r, qualifier))
|
|
paths = sorted(paths)
|
|
sizes = len(paths)
|
|
else:
|
|
paths = sorted(_get_paths_from_images(dataroot, qualifier))
|
|
sizes = len(paths)
|
|
return paths, sizes
|
|
|
|
|
|
def glob_file_list(root):
|
|
return sorted(glob.glob(os.path.join(root, '*')))
|
|
|
|
|
|
###################### read images ######################
|
|
def _read_img_lmdb(env, key, size):
|
|
"""read image from lmdb with key (w/ and w/o fixed size)
|
|
size: (C, H, W) tuple"""
|
|
with env.begin(write=False) as txn:
|
|
buf = txn.get(key.encode('ascii'))
|
|
img_flat = np.frombuffer(buf, dtype=np.uint8)
|
|
C, H, W = size
|
|
img = img_flat.reshape(H, W, C)
|
|
return img
|
|
|
|
|
|
def read_img(env, path, size=None, rgb=False):
|
|
"""read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
|
|
return: Numpy float32, HWC, BGR, [0,1]"""
|
|
if env is None: # img
|
|
# Indirect open then process to support unicode files.
|
|
stream = open(path, "rb")
|
|
bytes = bytearray(stream.read())
|
|
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
|
elif env == 'lmdb':
|
|
img = _read_img_lmdb(env, path, size)
|
|
elif env == 'buffer':
|
|
img = cv2.imdecode(path, cv2.IMREAD_UNCHANGED)
|
|
else:
|
|
raise NotImplementedError("Unsupported env: %s" % (env,))
|
|
|
|
if img is None:
|
|
print("Image error: %s" % (path,))
|
|
img = img.astype(np.float32) / 255.
|
|
if img.ndim == 2:
|
|
img = np.expand_dims(img, axis=2)
|
|
# some images have 4 channels
|
|
if img.shape[2] > 3:
|
|
img = img[:, :, :3]
|
|
|
|
if rgb:
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
return img
|
|
|
|
|
|
def read_img_seq(path):
|
|
"""Read a sequence of images from a given folder path
|
|
Args:
|
|
path (list/str): list of image paths/image folder path
|
|
|
|
Returns:
|
|
imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
|
|
"""
|
|
if type(path) is list:
|
|
img_path_l = path
|
|
else:
|
|
img_path_l = sorted(glob.glob(os.path.join(path, '*')))
|
|
img_l = [read_img(None, v) for v in img_path_l]
|
|
# stack to Torch tensor
|
|
imgs = np.stack(img_l, axis=0)
|
|
imgs = imgs[:, :, :, [2, 1, 0]]
|
|
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
|
return imgs
|
|
|
|
|
|
def index_generation(crt_i, max_n, N, padding='reflection'):
|
|
"""Generate an index list for reading N frames from a sequence of images
|
|
Args:
|
|
crt_i (int): current center index
|
|
max_n (int): max number of the sequence of images (calculated from 1)
|
|
N (int): reading N frames
|
|
padding (str): padding mode, one of replicate | reflection | new_info | circle
|
|
Example: crt_i = 0, N = 5
|
|
replicate: [0, 0, 0, 1, 2]
|
|
reflection: [2, 1, 0, 1, 2]
|
|
new_info: [4, 3, 0, 1, 2]
|
|
circle: [3, 4, 0, 1, 2]
|
|
|
|
Returns:
|
|
return_l (list [int]): a list of indexes
|
|
"""
|
|
max_n = max_n - 1
|
|
n_pad = N // 2
|
|
return_l = []
|
|
|
|
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
|
|
if i < 0:
|
|
if padding == 'replicate':
|
|
add_idx = 0
|
|
elif padding == 'reflection':
|
|
add_idx = -i
|
|
elif padding == 'new_info':
|
|
add_idx = (crt_i + n_pad) + (-i)
|
|
elif padding == 'circle':
|
|
add_idx = N + i
|
|
else:
|
|
raise ValueError('Wrong padding mode')
|
|
elif i > max_n:
|
|
if padding == 'replicate':
|
|
add_idx = max_n
|
|
elif padding == 'reflection':
|
|
add_idx = max_n * 2 - i
|
|
elif padding == 'new_info':
|
|
add_idx = (crt_i - n_pad) - (i - max_n)
|
|
elif padding == 'circle':
|
|
add_idx = i - N
|
|
else:
|
|
raise ValueError('Wrong padding mode')
|
|
else:
|
|
add_idx = i
|
|
return_l.append(add_idx)
|
|
return return_l
|
|
|
|
|
|
####################
|
|
# image processing
|
|
# process on numpy image
|
|
####################
|
|
|
|
|
|
def augment(img_list, hflip=True, rot=True):
|
|
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
|
|
hflip = hflip and random.random() < 0.5
|
|
vflip = rot and random.random() < 0.5
|
|
rot90 = rot and random.random() < 0.5
|
|
|
|
def _augment(img):
|
|
if hflip:
|
|
img = img[:, ::-1, :]
|
|
if vflip:
|
|
img = img[::-1, :, :]
|
|
if rot90:
|
|
img = img.transpose(1, 0, 2)
|
|
return img
|
|
|
|
return [_augment(img) for img in img_list]
|
|
|
|
|
|
def augment_flow(img_list, flow_list, hflip=True, rot=True):
|
|
"""horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows"""
|
|
hflip = hflip and random.random() < 0.5
|
|
vflip = rot and random.random() < 0.5
|
|
rot90 = rot and random.random() < 0.5
|
|
|
|
def _augment(img):
|
|
if hflip:
|
|
img = img[:, ::-1, :]
|
|
if vflip:
|
|
img = img[::-1, :, :]
|
|
if rot90:
|
|
img = img.transpose(1, 0, 2)
|
|
return img
|
|
|
|
def _augment_flow(flow):
|
|
if hflip:
|
|
flow = flow[:, ::-1, :]
|
|
flow[:, :, 0] *= -1
|
|
if vflip:
|
|
flow = flow[::-1, :, :]
|
|
flow[:, :, 1] *= -1
|
|
if rot90:
|
|
flow = flow.transpose(1, 0, 2)
|
|
flow = flow[:, :, [1, 0]]
|
|
return flow
|
|
|
|
rlt_img_list = [_augment(img) for img in img_list]
|
|
rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
|
|
|
|
return rlt_img_list, rlt_flow_list
|
|
|
|
|
|
def channel_convert(in_c, tar_type, img_list):
|
|
"""conversion among BGR, gray and y"""
|
|
if in_c == 3 and tar_type == 'gray': # BGR to gray
|
|
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
|
return [np.expand_dims(img, axis=2) for img in gray_list]
|
|
elif in_c == 3 and tar_type == 'y': # BGR to y
|
|
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
|
return [np.expand_dims(img, axis=2) for img in y_list]
|
|
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
|
|
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
|
else:
|
|
return img_list
|
|
|
|
|
|
def rgb2ycbcr(img, only_y=True):
|
|
"""same as matlab rgb2ycbcr
|
|
only_y: only return Y channel
|
|
Input:
|
|
uint8, [0, 255]
|
|
float, [0, 1]
|
|
"""
|
|
in_img_type = img.dtype
|
|
img.astype(np.float32)
|
|
if in_img_type != np.uint8:
|
|
img *= 255.
|
|
# convert
|
|
if only_y:
|
|
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
|
else:
|
|
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
|
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
|
if in_img_type == np.uint8:
|
|
rlt = rlt.round()
|
|
else:
|
|
rlt /= 255.
|
|
return rlt.astype(in_img_type)
|
|
|
|
|
|
def bgr2ycbcr(img, only_y=True):
|
|
"""bgr version of rgb2ycbcr
|
|
only_y: only return Y channel
|
|
Input:
|
|
uint8, [0, 255]
|
|
float, [0, 1]
|
|
"""
|
|
in_img_type = img.dtype
|
|
img.astype(np.float32)
|
|
if in_img_type != np.uint8:
|
|
img *= 255.
|
|
# convert
|
|
if only_y:
|
|
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
|
else:
|
|
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
|
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
|
if in_img_type == np.uint8:
|
|
rlt = rlt.round()
|
|
else:
|
|
rlt /= 255.
|
|
return rlt.astype(in_img_type)
|
|
|
|
|
|
def ycbcr2rgb(img):
|
|
"""same as matlab ycbcr2rgb
|
|
Input:
|
|
uint8, [0, 255]
|
|
float, [0, 1]
|
|
"""
|
|
in_img_type = img.dtype
|
|
img.astype(np.float32)
|
|
if in_img_type != np.uint8:
|
|
img *= 255.
|
|
# convert
|
|
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
|
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
|
if in_img_type == np.uint8:
|
|
rlt = rlt.round()
|
|
else:
|
|
rlt /= 255.
|
|
return rlt.astype(in_img_type)
|
|
|
|
|
|
def modcrop(img_in, scale):
|
|
"""img_in: Numpy, HWC or HW"""
|
|
img = np.copy(img_in)
|
|
if img.ndim == 2:
|
|
H, W = img.shape
|
|
H_r, W_r = H % scale, W % scale
|
|
img = img[:H - H_r, :W - W_r]
|
|
elif img.ndim == 3:
|
|
H, W, C = img.shape
|
|
H_r, W_r = H % scale, W % scale
|
|
img = img[:H - H_r, :W - W_r, :]
|
|
else:
|
|
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
|
|
return img
|
|
|
|
|
|
####################
|
|
# Functions
|
|
####################
|
|
|
|
|
|
# matlab 'imresize' function, now only support 'bicubic'
|
|
def cubic(x):
|
|
absx = torch.abs(x)
|
|
absx2 = absx**2
|
|
absx3 = absx**3
|
|
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
|
|
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
|
|
(absx > 1) * (absx <= 2)).type_as(absx))
|
|
|
|
|
|
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
|
if (scale < 1) and (antialiasing):
|
|
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
|
kernel_width = kernel_width / scale
|
|
|
|
# Output-space coordinates
|
|
x = torch.linspace(1, out_length, out_length)
|
|
|
|
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
|
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
|
# space maps to 1.5 in input space.
|
|
u = x / scale + 0.5 * (1 - 1 / scale)
|
|
|
|
# What is the left-most pixel that can be involved in the computation?
|
|
left = torch.floor(u - kernel_width / 2)
|
|
|
|
# What is the maximum number of pixels that can be involved in the
|
|
# computation? Note: it's OK to use an extra pixel here; if the
|
|
# corresponding weights are all zero, it will be eliminated at the end
|
|
# of this function.
|
|
P = math.ceil(kernel_width) + 2
|
|
|
|
# The indices of the input pixels involved in computing the k-th output
|
|
# pixel are in row k of the indices matrix.
|
|
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
|
|
1, P).expand(out_length, P)
|
|
|
|
# The weights used to compute the k-th output pixel are in row k of the
|
|
# weights matrix.
|
|
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
|
# apply cubic kernel
|
|
if (scale < 1) and (antialiasing):
|
|
weights = scale * cubic(distance_to_center * scale)
|
|
else:
|
|
weights = cubic(distance_to_center)
|
|
# Normalize the weights matrix so that each row sums to 1.
|
|
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
|
weights = weights / weights_sum.expand(out_length, P)
|
|
|
|
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
|
weights_zero_tmp = torch.sum((weights == 0), 0)
|
|
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
|
indices = indices.narrow(1, 1, P - 2)
|
|
weights = weights.narrow(1, 1, P - 2)
|
|
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
|
indices = indices.narrow(1, 0, P - 2)
|
|
weights = weights.narrow(1, 0, P - 2)
|
|
weights = weights.contiguous()
|
|
indices = indices.contiguous()
|
|
sym_len_s = -indices.min() + 1
|
|
sym_len_e = indices.max() - in_length
|
|
indices = indices + sym_len_s - 1
|
|
return weights, indices, int(sym_len_s), int(sym_len_e)
|
|
|
|
|
|
def imresize(img, scale, antialiasing=True):
|
|
# Now the scale should be the same for H and W
|
|
# input: img: CHW RGB [0,1]
|
|
# output: CHW RGB [0,1] w/o round
|
|
|
|
in_C, in_H, in_W = img.size()
|
|
_, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
|
kernel_width = 4
|
|
kernel = 'cubic'
|
|
|
|
# Return the desired dimension order for performing the resize. The
|
|
# strategy is to perform the resize first along the dimension with the
|
|
# smallest scale factor.
|
|
# Now we do not support this.
|
|
|
|
# get weights and indices
|
|
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
|
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
|
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
|
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
|
# process H dimension
|
|
# symmetric copying
|
|
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
|
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
|
|
|
sym_patch = img[:, :sym_len_Hs, :]
|
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
|
|
|
sym_patch = img[:, -sym_len_He:, :]
|
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
|
|
|
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
|
kernel_width = weights_H.size(1)
|
|
for i in range(out_H):
|
|
idx = int(indices_H[i][0])
|
|
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
|
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
|
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
|
|
|
# process W dimension
|
|
# symmetric copying
|
|
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
|
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
|
|
|
sym_patch = out_1[:, :, :sym_len_Ws]
|
|
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
|
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
|
|
|
sym_patch = out_1[:, :, -sym_len_We:]
|
|
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
|
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
|
|
|
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
|
kernel_width = weights_W.size(1)
|
|
for i in range(out_W):
|
|
idx = int(indices_W[i][0])
|
|
out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
|
|
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
|
|
out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
|
|
|
|
return out_2
|
|
|
|
|
|
def imresize_np(img, scale, antialiasing=True):
|
|
# Now the scale should be the same for H and W
|
|
# input: img: Numpy, HWC BGR [0,1]
|
|
# output: HWC BGR [0,1] w/o round
|
|
img = torch.from_numpy(img)
|
|
|
|
in_H, in_W, in_C = img.size()
|
|
_, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
|
kernel_width = 4
|
|
kernel = 'cubic'
|
|
|
|
# Return the desired dimension order for performing the resize. The
|
|
# strategy is to perform the resize first along the dimension with the
|
|
# smallest scale factor.
|
|
# Now we do not support this.
|
|
|
|
# get weights and indices
|
|
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
|
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
|
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
|
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
|
# process H dimension
|
|
# symmetric copying
|
|
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
|
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
|
|
|
sym_patch = img[:sym_len_Hs, :, :]
|
|
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
|
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
|
|
|
sym_patch = img[-sym_len_He:, :, :]
|
|
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
|
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
|
|
|
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
|
kernel_width = weights_H.size(1)
|
|
for i in range(out_H):
|
|
idx = int(indices_H[i][0])
|
|
out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
|
|
out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
|
|
out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
|
|
|
|
# process W dimension
|
|
# symmetric copying
|
|
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
|
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
|
|
|
sym_patch = out_1[:, :sym_len_Ws, :]
|
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
|
|
|
sym_patch = out_1[:, -sym_len_We:, :]
|
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
|
|
|
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
|
kernel_width = weights_W.size(1)
|
|
for i in range(out_W):
|
|
idx = int(indices_W[i][0])
|
|
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
|
|
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
|
|
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
|
|
|
|
return out_2.numpy()
|
|
|
|
|
|
def load_paths_from_cache(paths, cache_path):
|
|
if not isinstance(paths, list):
|
|
paths = [paths]
|
|
if os.path.exists(cache_path):
|
|
output = torch.load(cache_path)
|
|
else:
|
|
print(f"Building cache for contents of {paths}..")
|
|
output = []
|
|
for p in paths:
|
|
output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0])
|
|
torch.save(output, cache_path)
|
|
return output
|
|
|
|
|
|
if __name__ == '__main__':
|
|
# test imresize function
|
|
# read images
|
|
img = cv2.imread('test.png')
|
|
img = img * 1.0 / 255
|
|
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
|
# imresize
|
|
scale = 1 / 4
|
|
import time
|
|
total_time = 0
|
|
for i in range(10):
|
|
start_time = time.time()
|
|
rlt = imresize(img, scale, antialiasing=True)
|
|
use_time = time.time() - start_time
|
|
total_time += use_time
|
|
print('average time: {}'.format(total_time / 10))
|
|
|
|
import torchvision.utils
|
|
torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
|
|
normalize=False)
|