From f9dc472f63fb7d7799d85e2926b8bf4729ca6dd3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 22 Oct 2020 10:16:17 -0600
Subject: [PATCH] Misc nonfunctional mods to datasets

---
 codes/data/base_unsupervised_image_dataset.py |  1 +
 codes/data/multi_frame_dataset.py             | 32 +++++++++++++------
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/codes/data/base_unsupervised_image_dataset.py b/codes/data/base_unsupervised_image_dataset.py
index fdd2284e..1b0e021b 100644
--- a/codes/data/base_unsupervised_image_dataset.py
+++ b/codes/data/base_unsupervised_image_dataset.py
@@ -31,6 +31,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
             if os.path.exists(cache_path):
                 chunks = torch.load(cache_path)
             else:
+                print("Building chunk cache, this can take some time for large datasets..")
                 chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
                 # Prune out chunks that have no images
                 res = []
diff --git a/codes/data/multi_frame_dataset.py b/codes/data/multi_frame_dataset.py
index 111e1e17..2e9a5627 100644
--- a/codes/data/multi_frame_dataset.py
+++ b/codes/data/multi_frame_dataset.py
@@ -55,14 +55,14 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
 
-        return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
 
 
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['/content/fullvideo_256_tiled_test'],
+        'paths': ['F:\\4k6k\\datasets\\images\\fullvideo\\ge_fv_256_tiled'],
         'weights': [1],
         'target_size': 128,
         'force_multiple': 32,
@@ -77,14 +77,28 @@ if __name__ == '__main__':
     ds = MultiFrameDataset(opt)
     import os
     os.makedirs("debug", exist_ok=True)
-    for i in [3]:
+    bs = 0
+    batch = None
+    for i in range(len(ds)):
         import random
-        o = ds[i]
-        k = 'GT'
-        v = o[k]
+        k = 'LQ'
+        element = ds[random.randint(0,len(ds))]
+        base_file = osp.basename(element["GT_path"])
+        o = element[k].unsqueeze(0)
+        if bs < 32:
+            if batch is None:
+                batch = o
+            else:
+                batch = torch.cat([batch, o], dim=0)
+            bs += 1
+            continue
+
         if 'path' not in k and 'center' not in k:
-            fr, f, h, w = v.shape
+            b, fr, f, h, w = batch.shape
             for j in range(fr):
                 import torchvision
-                base=osp.basename(o["GT_path"])
-                torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))
+                base=osp.basename(base_file)
+                torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
+
+        bs = 0
+        batch = None