import torch from torch.utils.data import Dataset import torchvision.transforms as T from torchvision import datasets # Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer. class TorchDataset(Dataset): def __init__(self, opt): DATASET_MAP = { "mnist": datasets.MNIST, "fmnist": datasets.FashionMNIST, "cifar10": datasets.CIFAR10, } transforms = [] if opt['flip']: transforms.append(T.RandomHorizontalFlip()) if opt['crop_sz']: transforms.append(T.RandomCrop(opt['crop_sz'], padding=opt['padding'], padding_mode="reflect")) transforms.append(T.ToTensor()) transforms = T.Compose(transforms) is_for_training = opt['test'] if 'test' in opt.keys() else True self.dataset = DATASET_MAP[opt['dataset']](opt['datapath'], train=is_for_training, download=True, transform=transforms) self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset) def __getitem__(self, item): underlying_item = self.dataset[item][0] return {'lq': underlying_item, 'hq': underlying_item, 'LQ_path': str(item), 'GT_path': str(item)} def __len__(self): return self.len