forked from mrq/DL-Art-School

Improve multiframe dataset memory usage

parent afe6af88af
commit b3d0baaf17
@@ -30,9 +30,6 @@ class BaseUnsupervisedImageDataset(data.Dataset):
         cache_path = os.path.join(path, 'cache.pth')
         if os.path.exists(cache_path):
             chunks = torch.load(cache_path)
-            # Update the options.
-            for c in chunks:
-                c.reload(opt)
         else:
             chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
             # Prune out chunks that have no images
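Note on the removed lines: the per-chunk reload(opt) pass existed because torch.load brings the chunk objects back exactly as they were pickled, stale opt dict and all. Once ChunkWithReference stops storing opt (see the next hunk), there is nothing to refresh after loading cache.pth. A minimal, self-contained sketch of that pickling behaviour; the Chunk class and paths below are placeholders, not the repository's code:

    import torch

    class Chunk:
        """Stand-in for ChunkWithReference: pickled wholesale into cache.pth."""
        def __init__(self, path):
            self.path = path

    # torch.save pickles each object with every attribute it carries, so anything
    # cached on an instance (option dicts, center tables, reference images) ends up
    # in cache.pth and comes straight back on torch.load.
    chunks = [Chunk(f"/data/chunk_{i:03d}") for i in range(3)]
    torch.save(chunks, "cache.pth")
    restored = torch.load("cache.pth")
    print([c.path for c in restored])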
@@ -6,29 +6,15 @@ import numpy as np
 # Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
 class ChunkWithReference:
     def __init__(self, opt, path):
-        self.reload(opt)
         self.path = path.path
         self.tiles, _ = util.get_image_paths('img', self.path)
-        self.centers = None
 
-    def reload(self, opt):
-        self.opt = opt
-        self.ref = None # This is loaded on the fly.
-        self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else False
-
     def __getitem__(self, item):
-        # Load centers on the fly and always cache.
-        if self.centers is None:
-            self.centers = torch.load(osp.join(self.path, "centers.pt"))
-        if self.cache_ref:
-            if self.ref is None:
-                self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
-            ref = self.ref
-        else:
-            ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
+        centers = torch.load(osp.join(self.path, "centers.pt"))
+        ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
         tile = util.read_img(None, self.tiles[item], rgb=True)
         tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
-        center, tile_width = self.centers[tile_id]
+        center, tile_width = centers[tile_id]
         mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
         mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
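After this hunk, ChunkWithReference keeps only the chunk directory and the tile path list on the instance; the center table and the reference image are re-read inside __getitem__ and discarded with the sample, which is where the memory saving comes from. A rough, self-contained sketch of that pattern (the class name, image loader, and return value are illustrative stand-ins, not the repository's exact code):

    import os.path as osp
    import numpy as np
    import torch

    def load_image(path):
        """Stand-in for data.util.read_img(None, path, rgb=True); a random HWC
        float image keeps the sketch runnable without any files on disk."""
        return np.random.rand(64, 64, 3).astype(np.float32)

    class LazyChunk:
        """Illustration of the post-commit layout: nothing heavy lives on the instance."""
        def __init__(self, path, tiles):
            self.path = path      # chunk directory
            self.tiles = tiles    # list of tile image paths; cheap to keep around

        def __getitem__(self, item):
            # Re-read the center table and reference image on every call instead of
            # caching them on self, so only this sample's arrays stay alive.
            centers = torch.load(osp.join(self.path, "centers.pt"))
            ref = load_image(osp.join(self.path, "ref.jpg"))
            tile = load_image(self.tiles[item])
            tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
            center, tile_width = centers[tile_id]
            # Weight mask: mostly .1, with the tile's footprint in the reference set to 1.
            mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
            half = tile_width // 2
            mask[center[0] - half:center[0] + half,
                 center[1] - half:center[1] + half] = 1
            return tile, ref, center, mask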
@@ -56,8 +56,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
 
         return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
-                'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long),
-                'LQ_path': path, 'GT_path': path}
+                'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
 
 
 if __name__ == '__main__':
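With the 'LQ_path'/'GT_path' entries dropped, every value in the returned sample dict is a tensor, so the default DataLoader collate simply stacks each key along a new batch dimension. A small usage sketch; the dataset stub and all shapes below are made up for illustration and do not match the real dataset's dimensions:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class FakeMultiFrame(Dataset):
        """Illustration only: yields dicts shaped like the return statement above."""
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {'LQ': torch.rand(3, 16, 16), 'GT': torch.rand(3, 64, 64),
                    'gt_fullsize_ref': torch.rand(4, 64, 64), 'lq_fullsize_ref': torch.rand(4, 16, 16),
                    'lq_center': torch.tensor([8, 8], dtype=torch.long),
                    'gt_center': torch.tensor([32, 32], dtype=torch.long)}

    # The default collate_fn stacks each tensor-valued key into a leading batch dimension.
    loader = DataLoader(FakeMultiFrame(), batch_size=4)
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})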