forked from mrq/DL-Art-School
Move datasets to INTER_AREA interpolation for downsizing
Looks **FAR** better visually
This commit is contained in:
parent
d56745b2ec
commit
1ba01d69b5
|
@ -68,9 +68,9 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
||||||
for hq, hq_ref, hq_mask, hq_center in zip(imgs_hq, refs_hq, masks_hq, centers_hq):
|
for hq, hq_ref, hq_mask, hq_center in zip(imgs_hq, refs_hq, masks_hq, centers_hq):
|
||||||
# It is assumed that the target size is a square.
|
# It is assumed that the target size is a square.
|
||||||
target_size = (self.target_hq_size, self.target_hq_size)
|
target_size = (self.target_hq_size, self.target_hq_size)
|
||||||
hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_LINEAR))
|
hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_AREA))
|
||||||
hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_LINEAR))
|
hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_AREA))
|
||||||
hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_LINEAR))
|
hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_AREA))
|
||||||
hq_centers_adjusted.append(self.resize_point(hq_center, (h, w), target_size))
|
hq_centers_adjusted.append(self.resize_point(hq_center, (h, w), target_size))
|
||||||
h, w = self.target_hq_size, self.target_hq_size
|
h, w = self.target_hq_size, self.target_hq_size
|
||||||
else:
|
else:
|
||||||
|
@ -97,9 +97,9 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
||||||
lms.append(hq_mask)
|
lms.append(hq_mask)
|
||||||
lcs.append(hq_center)
|
lcs.append(hq_center)
|
||||||
else:
|
else:
|
||||||
ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR))
|
ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
|
||||||
lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR))
|
lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
|
||||||
lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR))
|
lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
|
||||||
lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
|
lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
|
||||||
# Corrupt the LQ image (only in eval mode)
|
# Corrupt the LQ image (only in eval mode)
|
||||||
if not self.for_eval:
|
if not self.for_eval:
|
||||||
|
|
|
@ -40,7 +40,7 @@ class MultiScaleDataset(data.Dataset):
|
||||||
return image[:, left:left+h, :]
|
return image[:, left:left+h, :]
|
||||||
|
|
||||||
def recursively_extract_patches(self, input_img, result_list, depth):
|
def recursively_extract_patches(self, input_img, result_list, depth):
|
||||||
if depth > self.num_scales:
|
if depth >= self.num_scales:
|
||||||
return
|
return
|
||||||
patch_size = self.hq_size_cap // (2 ** depth)
|
patch_size = self.hq_size_cap // (2 ** depth)
|
||||||
# First pull the four sub-patches.
|
# First pull the four sub-patches.
|
||||||
|
@ -48,7 +48,7 @@ class MultiScaleDataset(data.Dataset):
|
||||||
input_img[:patch_size, patch_size:],
|
input_img[:patch_size, patch_size:],
|
||||||
input_img[patch_size:, :patch_size],
|
input_img[patch_size:, :patch_size],
|
||||||
input_img[patch_size:, patch_size:]]
|
input_img[patch_size:, patch_size:]]
|
||||||
result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_LINEAR) for p in patches])
|
result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches])
|
||||||
for p in patches:
|
for p in patches:
|
||||||
self.recursively_extract_patches(p, result_list, depth+1)
|
self.recursively_extract_patches(p, result_list, depth+1)
|
||||||
|
|
||||||
|
@ -59,8 +59,8 @@ class MultiScaleDataset(data.Dataset):
|
||||||
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
|
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
|
||||||
img_full = util.augment([img_full], True, True)[0]
|
img_full = util.augment([img_full], True, True)[0]
|
||||||
img_full = self.get_square_image(img_full)
|
img_full = self.get_square_image(img_full)
|
||||||
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_LINEAR)
|
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
|
||||||
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_LINEAR)]
|
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||||
self.recursively_extract_patches(img_full, patches_hq, 1)
|
self.recursively_extract_patches(img_full, patches_hq, 1)
|
||||||
|
|
||||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||||
|
@ -75,24 +75,34 @@ class MultiScaleDataset(data.Dataset):
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.paths_hq)
|
return len(self.paths_hq)
|
||||||
|
|
||||||
|
class MultiscaleTreeNode:
|
||||||
|
def __init__(self, index, parent):
|
||||||
|
self.index = index
|
||||||
|
self.parent = parent
|
||||||
|
self.children = []
|
||||||
|
|
||||||
|
def add_child(self, child):
|
||||||
|
self.children.append(child)
|
||||||
|
return child
|
||||||
|
|
||||||
|
|
||||||
def build_multiscale_patch_index_map(depth):
|
def build_multiscale_patch_index_map(depth):
|
||||||
if depth < 0:
|
if depth < 0:
|
||||||
return
|
return
|
||||||
recursive_list = []
|
root = MultiscaleTreeNode(0, None)
|
||||||
map = (0, recursive_list)
|
leaves = []
|
||||||
_build_multiscale_patch_index_map(depth, 1, recursive_list)
|
_build_multiscale_patch_index_map(depth-1, 1, root, leaves)
|
||||||
return map
|
return leaves
|
||||||
|
|
||||||
|
|
||||||
def _build_multiscale_patch_index_map(depth, ind, recursive_list):
|
def _build_multiscale_patch_index_map(depth, ind, node, leaves):
|
||||||
if depth <= 0:
|
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node)) for i in range(4)]
|
||||||
return ind
|
|
||||||
patches = [(ind+i, []) for i in range(4)]
|
|
||||||
recursive_list.extend(patches)
|
|
||||||
ind += 4
|
ind += 4
|
||||||
for _, p in patches:
|
if depth == 1:
|
||||||
ind = _build_multiscale_patch_index_map(depth-1, ind, p)
|
leaves.extend(subnodes)
|
||||||
|
else:
|
||||||
|
for n in subnodes:
|
||||||
|
ind = _build_multiscale_patch_index_map(depth-1, ind, n, leaves)
|
||||||
return ind
|
return ind
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,18 +119,19 @@ if __name__ == '__main__':
|
||||||
ds = MultiScaleDataset(opt)
|
ds = MultiScaleDataset(opt)
|
||||||
import os
|
import os
|
||||||
os.makedirs("debug", exist_ok=True)
|
os.makedirs("debug", exist_ok=True)
|
||||||
multiscale_map = build_multiscale_patch_index_map(4)
|
multiscale_tree = build_multiscale_patch_index_map(4)
|
||||||
for i in range(900, len(ds)):
|
for i in range(500, len(ds)):
|
||||||
quadrant=2
|
quadrant=2
|
||||||
print(i)
|
print(i)
|
||||||
o = ds[i]
|
o = ds[random.randint(0, len(ds))]
|
||||||
k = 'HQ'
|
k = 'HQ'
|
||||||
v = o['HQ']
|
v = o['HQ']
|
||||||
#for j, img in enumerate(v):
|
#for j, img in enumerate(v):
|
||||||
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
||||||
torchvision.utils.save_image(v[0].unsqueeze(0), "debug/%i_%s_0.png" % (i, k))
|
tree_ind = random.randint(0, len(multiscale_tree))
|
||||||
map_tuple = multiscale_map[1][quadrant]
|
node = multiscale_tree[tree_ind]
|
||||||
while map_tuple[1]:
|
depth = 0
|
||||||
ind = map_tuple[0]
|
while node is not None:
|
||||||
torchvision.utils.save_image(v[ind].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, ind+1))
|
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
|
||||||
map_tuple = map_tuple[1][quadrant]
|
depth += 1
|
||||||
|
node = node.parent
|
Loading…
Reference in New Issue
Block a user