DL-Art-School/codes/data/chunk_with_reference.py

39 lines
1.5 KiB
Python
Raw Normal View History

2020-09-18 15:49:13 +00:00
import os.path as osp
from data import util
import torch
import numpy as np
2020-09-18 15:49:13 +00:00
# Iterable that reads all the images in a directory that contains a reference image,
# tile images and center coordinates.
class ChunkWithReference:
    """Indexable view over one chunk directory.

    The directory is expected to contain:
      * ``ref.jpg`` -- a reference image,
      * tile images discovered via ``util.get_image_paths('img', path)``, whose
        basenames (without extension) are integer tile ids,
      * ``centers.pt`` -- an object loadable by ``torch.load`` and indexable by
        tile id, each entry unpacking to ``(center, tile_width)``.

    Instances support ``len()`` and indexing. They are also iterable through the
    legacy ``__getitem__`` protocol, provided indexing past the end raises
    ``IndexError`` (true if ``self.tiles`` is a list -- presumably it is).
    """

    def __init__(self, opt, path):
        """
        Args:
            opt: option mapping; ``reload`` reads the optional ``'cache_ref'`` key.
            path: chunk directory entry exposing ``.path`` (presumably an
                ``os.DirEntry``) and accepted by ``util.get_image_paths``.
        """
        self.reload(opt)
        self.path = path.path
        self.tiles, _ = util.get_image_paths('img', path)
        self.centers = None  # Loaded lazily on the first __getitem__ and then kept.

    def reload(self, opt):
        """(Re)apply ``opt`` and drop any cached reference image."""
        self.opt = opt
        self.ref = None  # This is loaded on the fly.
        self.cache_ref = opt.get('cache_ref', False)

    def _load_ref(self):
        """Read this chunk's ``ref.jpg`` from disk."""
        return util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)

    @staticmethod
    def _center_mask(tile, center, tile_width):
        """Return a ``tile.shape[:2] + (1,)`` mask of ``tile.dtype``.

        The mask is 1 inside the square of side ~``tile_width`` around ``center``
        and .1 elsewhere.
        """
        half = tile_width // 2
        # Clamp bounds at 0: a negative slice index would count from the far end
        # of the array and silently yield an empty or misplaced region. Bounds
        # past the array end are already truncated by numpy slicing.
        top = max(0, center[0] - half)
        bottom = max(0, center[0] + half)
        left = max(0, center[1] - half)
        right = max(0, center[1] + half)
        mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
        mask[top:bottom, left:right] = 1
        return mask

    def __getitem__(self, item):
        """Return ``(tile, ref, center, mask, tile_path)`` for tile index ``item``."""
        # Load centers on the fly and always cache.
        # NOTE: torch.load unpickles its input; only use on trusted data.
        if self.centers is None:
            self.centers = torch.load(osp.join(self.path, "centers.pt"))

        if self.cache_ref:
            if self.ref is None:
                self.ref = self._load_ref()
            ref = self.ref
        else:
            ref = self._load_ref()

        tile_path = self.tiles[item]
        tile = util.read_img(None, tile_path, rgb=True)
        # Tile files are named "<tile_id>.<ext>"; the id keys into self.centers.
        tile_id = int(osp.splitext(osp.basename(tile_path))[0])
        center, tile_width = self.centers[tile_id]
        mask = self._center_mask(tile, center, tile_width)

        return tile, ref, center, mask, tile_path

    def __len__(self):
        """Number of tile images in the chunk."""
        return len(self.tiles)