import zipfile

import PIL.Image
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize


class ZipFileDataset(torch.utils.data.Dataset):
    """Image dataset whose samples are the members of a single zip archive.

    ``opt`` keys read here:
        path: filesystem path of the zip archive.
        resolution: forwarded to ``torchvision.transforms.Resize``.
        paired_mode: when true, a member ending in ``0.jpg`` is paired with the
            same name ending in ``1.jpg`` (and vice versa); the partner image is
            returned under ``'alt_hq'``.

    Each item is a dict with ``'hq'`` (transformed image tensor), ``'HQ_path'``
    (the archive member name) and ``'has_alt'`` (the ``paired_mode`` flag), plus
    ``'alt_hq'`` in paired mode.
    """

    def __init__(self, opt):
        self.path = opt['path']
        # Open the archive only long enough to list its members; the handle used
        # for reading is opened lazily by get_zip().
        with zipfile.ZipFile(self.path) as archive:
            # Directory entries (names ending in '/') can never decode as images.
            self.all_files = [name for name in archive.namelist() if not name.endswith('/')]
        self.resolution = opt['resolution']
        self.paired_mode = opt['paired_mode']
        self.transforms = Compose([ToTensor(),
                                   Resize(self.resolution),
                                   Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                   ])
        self.zip = None

    def __len__(self):
        return len(self.all_files)

    def __getstate__(self):
        # ZipFile handles cannot be pickled (e.g. when DataLoader workers are
        # spawned). Drop the handle so each copy reopens its own via get_zip().
        state = self.__dict__.copy()
        state['zip'] = None
        return state

    # Loaded on the fly because ZipFile does not tolerate pickling.
    def get_zip(self):
        """Return the (lazily opened) ZipFile used for reading members."""
        if self.zip is None:
            self.zip = zipfile.ZipFile(self.path)
        return self.zip

    def load_image(self, path):
        """Decode archive member ``path`` as RGB and apply ``self.transforms``."""
        with self.get_zip().open(path, 'r') as file, PIL.Image.open(file) as img:
            # PIL decodes lazily; convert() forces the decode while the member is
            # still open, and yields the 3 channels that Normalize expects.
            pilimg = img.convert('RGB')
        return self.transforms(pilimg)

    @staticmethod
    def _alt_name(fname):
        """Return the paired member name by swapping only the trailing 0/1 digit.

        Names ending in neither ``0.jpg`` nor ``1.jpg`` are returned unchanged
        (the original behavior), so such a sample is paired with itself.
        """
        if fname.endswith('0.jpg'):
            return fname[:-len('0.jpg')] + '1.jpg'
        if fname.endswith('1.jpg'):
            return fname[:-len('1.jpg')] + '0.jpg'
        return fname

    def _load_item(self, fname):
        """Build the sample dict for archive member ``fname`` (may raise)."""
        out = {
            'hq': self.load_image(fname),
            'HQ_path': fname,
            'has_alt': self.paired_mode
        }
        if self.paired_mode:
            out['alt_hq'] = self.load_image(self._alt_name(fname))
        return out

    def __getitem__(self, i):
        """Return sample ``i``; on a load failure, fall back to the next members.

        Raises:
            IndexError: if ``i`` is out of range.
            RuntimeError: if no member of the archive can be loaded.
        """
        n = len(self.all_files)
        if not -n <= i < n:
            raise IndexError(f"index {i} out of range for dataset of size {n}")
        start = i % n
        # Bounded, wrap-around retry: try each member at most once starting at i,
        # instead of unbounded recursion that ran off the end of the list.
        for offset in range(n):
            fname = self.all_files[(start + offset) % n]
            try:
                return self._load_item(fname)
            except Exception as e:
                print(f"Error loading {fname} from zipfile ({e!r}). Attempting to recover by loading next element.")
        raise RuntimeError(f"Could not load any element from {self.path}")


if __name__ == '__main__':
    opt = {
        'path': 'E:\\4k6k\\datasets\\images\\youtube-imagenet-paired\\output.zip',
        'resolution': 224,
        'paired_mode': True
    }

    dataset = ZipFileDataset(opt)
    print(len(dataset))
    loader = DataLoader(dataset, shuffle=True)
    for i, d in enumerate(loader):
        # Samples are ImageNet-normalized; rescale to [0, 1] so the PNGs are viewable.
        torchvision.utils.save_image(d['hq'], f'{i}_hq.png', normalize=True)
        if 'alt_hq' in d:
            torchvision.utils.save_image(d['alt_hq'], f'{i}_althq.png', normalize=True)