2020-09-18 15:49:13 +00:00
|
|
|
import os.path as osp
|
|
|
|
from data import util
|
|
|
|
import torch
|
2020-09-25 22:37:54 +00:00
|
|
|
import numpy as np
|
2020-09-18 15:49:13 +00:00
|
|
|
|
|
|
|
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
|
|
|
|
class ChunkWithReference:
    """Indexable collection of the tile images stored in one chunk directory.

    Besides the tile images, the directory is read for a reference image named
    ``ref.jpg`` and a ``centers.pt`` file. ``centers.pt`` is indexed by the
    integer tile id (the tile's file name without its extension) and yields a
    ``(center, tile_width)`` pair for that tile.
    """

    def __init__(self, opt, path):
        """Record the chunk directory and enumerate its tile images.

        Args:
            opt: Dataset options. Not used by this class.
            path: Object whose ``path`` attribute is the chunk directory
                (presumably an ``os.DirEntry`` -- confirm against the caller).
        """
        self.path = path.path
        self.tiles, _ = util.get_image_paths('img', self.path)

    def read_image_or_get_zero(self, img_path):
        """Read an image, substituting an all-zero image if the read fails.

        Odd failures occur at times. Rather than crashing, fall back to zeros
        so that a single unreadable file does not abort the whole run.

        Args:
            img_path: Path of the image file to read.

        Returns:
            The array produced by ``util.read_img``, or a ``(128, 128, 3)``
            array of zeros when ``util.read_img`` returns ``None``.
        """
        img = util.read_img(None, img_path, rgb=True)
        if img is None:
            # np.zeros takes the shape as ONE tuple; passing the dimensions as
            # separate positional arguments makes NumPy read 128 as the dtype
            # and raise a TypeError.
            # NOTE(review): the fallback uses NumPy's default dtype (float64);
            # confirm whether it should match what util.read_img returns.
            return np.zeros((128, 128, 3))
        # The successfully loaded image must be handed back; without this the
        # method returned None for every readable file.
        return img

    def __getitem__(self, item):
        """Return the tile at index ``item`` together with its reference data.

        Args:
            item: Index into the list of tile paths.

        Returns:
            A tuple ``(tile, ref, center, mask, tile_path)``. ``mask`` has the
            tile's height and width plus a trailing channel axis of size 1. It
            is filled with 0.1, except for a ``tile_width``-sized square
            window around ``center`` which is set to 1.
        """
        # NOTE: centers.pt and ref.jpg are re-read from disk on every access.
        centers = torch.load(osp.join(self.path, "centers.pt"))
        ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
        tile_path = self.tiles[item]
        tile = self.read_image_or_get_zero(tile_path)

        # Tiles are named by their integer id, which indexes centers.pt.
        tile_id = int(osp.splitext(osp.basename(tile_path))[0])
        center, tile_width = centers[tile_id]

        mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
        half = tile_width // 2
        # center[0] indexes the first (row) axis, center[1] the second (column) axis.
        mask[center[0] - half:center[0] + half, center[1] - half:center[1] + half] = 1

        return tile, ref, center, mask, tile_path

    def __len__(self):
        """Return the number of tile images found in the chunk directory."""
        return len(self.tiles)
|