DL-Art-School/codes/data/text/hf_datasets_wrapper.py

from torch.utils.data import Dataset
import datasets


class HfDataset(Dataset):
    """
    Simple wrapper that concatenates one or more HuggingFace datasets and can
    re-map their keys if desired.

    corpi: list of (dataset_name, config) pairs fed to datasets.load_dataset();
        an empty-string config is treated as no config.
    key_maps: optional dict of {output_key: source_key} used to rename fields.
    """
    def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
        self.hfd = []
        for corpus in corpi:
            dataset_name, config = corpus
            if config == '':
                config = None
            self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
        self.key_maps = key_maps

    def __getitem__(self, item):
        # Walk the concatenated datasets to locate the one that holds this global index.
        for dataset in self.hfd:
            if item < len(dataset):
                val = dataset[item]
                if self.key_maps is None:
                    return val
                else:
                    # Rename fields according to {output_key: source_key}.
                    return {k: val[v] for k, v in self.key_maps.items()}
            else:
                item -= len(dataset)
        raise IndexError()

    def __len__(self):
        return sum([len(h) for h in self.hfd])


if __name__ == '__main__':
    d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
    print(d[5])
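
    # A minimal sketch (not from the original file) of the key_maps feature: this assumes
    # bookcorpus exposes a 'text' column, which is re-exposed here under the key 'content'.
    remapped = HfDataset([['bookcorpus', '']], key_maps={'content': 'text'}, cache_path='Z:\\huggingface_datasets\\cache')
    print(remapped[0]['content'])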