From b3d0baaf174c291ca0471e39b81bceb0aa469fdd Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 9 Oct 2020 08:40:00 -0600 Subject: [PATCH] Improve multiframe dataset memory usage --- codes/data/base_unsupervised_image_dataset.py | 3 --- codes/data/chunk_with_reference.py | 20 +++---------------- codes/data/multi_frame_dataset.py | 3 +-- 3 files changed, 4 insertions(+), 22 deletions(-) diff --git a/codes/data/base_unsupervised_image_dataset.py b/codes/data/base_unsupervised_image_dataset.py index 1cf3db13..7b7fe10a 100644 --- a/codes/data/base_unsupervised_image_dataset.py +++ b/codes/data/base_unsupervised_image_dataset.py @@ -30,9 +30,6 @@ class BaseUnsupervisedImageDataset(data.Dataset): cache_path = os.path.join(path, 'cache.pth') if os.path.exists(cache_path): chunks = torch.load(cache_path) - # Update the options. - for c in chunks: - c.reload(opt) else: chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()] # Prune out chunks that have no images diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index 4da75f4d..dc437180 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -6,29 +6,15 @@ import numpy as np # Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates. class ChunkWithReference: def __init__(self, opt, path): - self.reload(opt) self.path = path.path self.tiles, _ = util.get_image_paths('img', self.path) - self.centers = None - - def reload(self, opt): - self.opt = opt - self.ref = None # This is loaded on the fly. - self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else False def __getitem__(self, item): - # Load centers on the fly and always cache. - if self.centers is None: - self.centers = torch.load(osp.join(self.path, "centers.pt")) - if self.cache_ref: - if self.ref is None: - self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True) - ref = self.ref - else: - ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True) + centers = torch.load(osp.join(self.path, "centers.pt")) + ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True) tile = util.read_img(None, self.tiles[item], rgb=True) tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0]) - center, tile_width = self.centers[tile_id] + center, tile_width = centers[tile_id] mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 diff --git a/codes/data/multi_frame_dataset.py b/codes/data/multi_frame_dataset.py index 17cc43fb..111e1e17 100644 --- a/codes/data/multi_frame_dataset.py +++ b/codes/data/multi_frame_dataset.py @@ -56,8 +56,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset): lq_ref = torch.cat([lq_ref, lq_mask], dim=1) return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, - 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long), - 'LQ_path': path, 'GT_path': path} + 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} if __name__ == '__main__':