forked from mrq/DL-Art-School
Add zipfilesdataset
This commit is contained in:
parent
1a2b9fa130
commit
6649ef2dae
|
@ -55,6 +55,8 @@ def create_dataset(dataset_opt):
|
||||||
from data.byol_attachment import DatasetRandomAugWrapper as D
|
from data.byol_attachment import DatasetRandomAugWrapper as D
|
||||||
elif mode == 'random_dataset':
|
elif mode == 'random_dataset':
|
||||||
from data.random_dataset import RandomDataset as D
|
from data.random_dataset import RandomDataset as D
|
||||||
|
elif mode == 'zipfile':
|
||||||
|
from data.zip_file_dataset import ZipFileDataset as D
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||||
dataset = D(dataset_opt)
|
dataset = D(dataset_opt)
|
||||||
|
|
64
codes/data/zip_file_dataset.py
Normal file
64
codes/data/zip_file_dataset.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import PIL.Image
|
||||||
|
import zipfile
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
|
||||||
|
|
||||||
|
|
||||||
|
class ZipFileDataset(torch.utils.data.Dataset):
    """Dataset that serves images stored as members of a single zip archive.

    Expected keys of ``opt``:
        path: filesystem path of the zip archive.
        resolution: value handed to torchvision's ``Resize`` transform.
        paired_mode: when truthy, every sample also carries the partner
            image whose name differs only in the trailing ``0.jpg`` /
            ``1.jpg`` suffix.

    Each item is a dict with ``'hq'`` (transformed image tensor),
    ``'HQ_path'`` (member name inside the archive) and ``'has_alt'``
    (the ``paired_mode`` flag); in paired mode ``'alt_hq'`` holds the
    partner image tensor.
    """

    def __init__(self, opt):
        super().__init__()
        self.path = opt['path']
        # Only the member listing is needed here, so the archive is closed
        # again right away instead of leaking an open handle.
        with zipfile.ZipFile(self.path) as archive:
            # Directory entries end with '/' and cannot be decoded as images.
            self.all_files = [name for name in archive.namelist()
                              if not name.endswith('/')]
        self.resolution = opt['resolution']
        self.paired_mode = opt['paired_mode']
        # NOTE(review): mean/std look like the commonly used ImageNet
        # statistics - confirm against the consuming model.
        self.transforms = Compose([ToTensor(),
                                   Resize(self.resolution),
                                   Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                   ])
        self.zip = None

    def __getstate__(self):
        """Return the instance state without the open archive handle.

        ZipFile objects cannot be pickled; dropping the cached handle lets
        the dataset be pickled even after it has already served items. The
        copy reopens the archive lazily through ``get_zip``.
        """
        state = self.__dict__.copy()
        state['zip'] = None
        return state

    def __len__(self):
        """Return the number of indexed archive members."""
        return len(self.all_files)

    # Loaded on the fly because ZipFile does not tolerate pickling.
    def get_zip(self):
        """Return the cached ZipFile for ``self.path``, opening it on first use."""
        if self.zip is None:
            self.zip = zipfile.ZipFile(self.path)
        return self.zip

    def load_image(self, path):
        """Decode the archive member ``path`` and return the transformed tensor."""
        with self.get_zip().open(path, 'r') as member:
            pilimg = PIL.Image.open(member)
            # Normalize is configured for three channels; grayscale, palette
            # or RGBA members would otherwise fail with a channel mismatch.
            if pilimg.mode != 'RGB':
                pilimg = pilimg.convert('RGB')
            # PIL decodes lazily, so the transform has to run while the
            # member stream is still open.
            return self.transforms(pilimg)

    @staticmethod
    def _paired_name(fname):
        """Return the partner member name by swapping the trailing 0.jpg/1.jpg.

        Only the suffix is rewritten, so a directory component that happens
        to contain ``0.jpg`` or ``1.jpg`` is left alone. Names with neither
        suffix are returned unchanged.
        """
        for suffix, partner in (('0.jpg', '1.jpg'), ('1.jpg', '0.jpg')):
            if fname.endswith(suffix):
                return fname[:-len(suffix)] + partner
        return fname

    def __getitem__(self, i):
        """Return the sample dict for index ``i`` (see the class docstring)."""
        fname = self.all_files[i]
        out = {
            'hq': self.load_image(fname),
            'HQ_path': fname,
            'has_alt': self.paired_mode
        }
        if self.paired_mode:
            out['alt_hq'] = self.load_image(self._paired_name(fname))
        return out
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual smoke test: open a local paired archive, report its size, then
    # dump every sample and its partner image to PNG files in the working
    # directory.
    archive_path = 'E:\\4k6k\\datasets\\images\\youtube-imagenet-paired\\output.zip'
    smoke_opt = {'path': archive_path, 'resolution': 224, 'paired_mode': True}
    smoke_set = ZipFileDataset(smoke_opt)
    print(len(smoke_set))
    batch_no = 0
    for batch in DataLoader(smoke_set, shuffle=True):
        torchvision.utils.save_image(batch['hq'], f'{batch_no}_hq.png')
        torchvision.utils.save_image(batch['alt_hq'], f'{batch_no}_althq.png')
        batch_no += 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user