From 66cbae873142beaa11dbb2a89fa29ffc825d301f Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 9 Dec 2020 14:55:05 -0700 Subject: [PATCH] Add random_dataset for testing --- codes/data/__init__.py | 2 ++ codes/data/random_dataset.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 codes/data/random_dataset.py diff --git a/codes/data/__init__.py b/codes/data/__init__.py index ab6797a4..c62c3700 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -49,6 +49,8 @@ def create_dataset(dataset_opt): from data.torch_dataset import TorchDataset as D elif mode == 'byol_dataset': from data.byol_attachment import ByolDatasetWrapper as D + elif mode == 'random_dataset': + from data.random_dataset import RandomDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) diff --git a/codes/data/random_dataset.py b/codes/data/random_dataset.py new file mode 100644 index 00000000..8e140061 --- /dev/null +++ b/codes/data/random_dataset.py @@ -0,0 +1,17 @@ +import torch +from torch.utils.data import Dataset + + +# Dataset that feeds random data into the state. Can be useful for testing or demo purposes without actual data. +class RandomDataset(Dataset): + def __init__(self, opt): + self.hq_shape = tuple(opt['hq_shape']) + self.lq_shape = tuple(opt['lq_shape']) + + def __getitem__(self, item): + return {'lq': torch.rand(self.lq_shape), 'hq': torch.rand(self.hq_shape), + 'LQ_path': '', 'GT_path': ''} + + def __len__(self): + # Arbitrary + return 1024 * 1024