diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index 177a0280..0393a67c 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -8,6 +8,7 @@ class ChunkWithReference: def __init__(self, opt, path): self.path = path.path self.tiles, _ = util.get_image_paths('img', self.path) + self.strict = opt['strict'] if 'strict' in opt.keys() else True if 'ignore_first' in opt.keys(): self.ignore = opt['ignore_first'] self.tiles = self.tiles[self.ignore:] @@ -22,11 +23,17 @@ class ChunkWithReference: return img def __getitem__(self, item): - centers = torch.load(osp.join(self.path, "centers.pt"))[self.ignore:] + centers = torch.load(osp.join(self.path, "centers.pt")) ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg")) tile = self.read_image_or_get_zero(self.tiles[item]) tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0]) - center, tile_width = centers[tile_id] + if tile_id in centers.keys(): + center, tile_width = centers[tile_id] + elif self.strict: + raise FileNotFoundError(tile_id, self.tiles[item]) + else: + center = torch.tensor([128,128], dtype=torch.long) + tile_width = 256 mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 diff --git a/codes/scripts/extract_subimages_with_ref.py b/codes/scripts/extract_subimages_with_ref.py index 9c69318f..86f3070d 100644 --- a/codes/scripts/extract_subimages_with_ref.py +++ b/codes/scripts/extract_subimages_with_ref.py @@ -5,8 +5,6 @@ import numpy as np import cv2 from PIL import Image import data.util as data_util # noqa: E402 -import lmdb -import pyarrow import torch.utils.data as data from tqdm import tqdm import torch @@ -16,7 +14,7 @@ def main(): mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs) split_img = False opt = {} - opt['n_thread'] = 0 + opt['n_thread'] = 2 opt['compression_level'] = 90 # JPEG compression quality rating. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. @@ -244,7 +242,7 @@ class TiledDataset(data.Dataset): h, w, c = img.shape # Uncomment to filter any image that doesnt meet a threshold size. - if min(h,w) < 1024: + if min(h,w) < 512: return None left = 0 right = w