diff --git a/codes/data/__init__.py b/codes/data/__init__.py index 5d33a206..60bfc90a 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -55,6 +55,8 @@ def create_dataset(dataset_opt): from data.byol_attachment import DatasetRandomAugWrapper as D elif mode == 'random_dataset': from data.random_dataset import RandomDataset as D + elif mode == 'zipfile': + from data.zip_file_dataset import ZipFileDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) diff --git a/codes/data/zip_file_dataset.py b/codes/data/zip_file_dataset.py new file mode 100644 index 00000000..513391c3 --- /dev/null +++ b/codes/data/zip_file_dataset.py @@ -0,0 +1,64 @@ +import PIL.Image +import zipfile +import torch +import torchvision +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, ToTensor, Normalize, Resize + + +class ZipFileDataset(torch.utils.data.Dataset): + def __init__(self, opt): + self.path = opt['path'] + zip = zipfile.ZipFile(self.path) + self.all_files = list(zip.namelist()) + self.resolution = opt['resolution'] + self.paired_mode = opt['paired_mode'] + self.transforms = Compose([ToTensor(), + Resize(self.resolution), + Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + ]) + self.zip = None + + def __len__(self): + return len(self.all_files) + + # Loaded on the fly because ZipFile does not tolerate pickling. + def get_zip(self): + if self.zip is None: + self.zip = zipfile.ZipFile(self.path) + return self.zip + + def load_image(self, path): + file = self.get_zip().open(path, 'r') + pilimg = PIL.Image.open(file) + tensor = self.transforms(pilimg) + return tensor + + def __getitem__(self, i): + fname = self.all_files[i] + out = { + 'hq': self.load_image(fname), + 'HQ_path': fname, + 'has_alt': self.paired_mode + } + if self.paired_mode: + if fname.endswith('0.jpg'): + aname = fname.replace('0.jpg', '1.jpg') + else: + aname = fname.replace('1.jpg', '0.jpg') + out['alt_hq'] = self.load_image(aname) + return out + +if __name__ == '__main__': + opt = { + 'path': 'E:\\4k6k\\datasets\\images\\youtube-imagenet-paired\\output.zip', + 'resolution': 224, + 'paired_mode': True + } + dataset = ZipFileDataset(opt) + print(len(dataset)) + loader = DataLoader(dataset, shuffle=True) + for i, d in enumerate(loader): + torchvision.utils.save_image(d['hq'], f'{i}_hq.png') + torchvision.utils.save_image(d['alt_hq'], f'{i}_althq.png') +