2021-12-23 21:32:33 +00:00
|
|
|
from torch.utils.data import Dataset
|
|
|
|
import datasets
|
|
|
|
|
|
|
|
|
|
|
|
class HfDataset(Dataset):
    """
    Simple wrapper for one or more HuggingFace datasets that can re-map keys if desired.

    The loaded datasets are exposed as one concatenated map-style dataset.
    Indices ``0 .. len(first) - 1`` address the first corpus, the following
    indices address the second corpus, and so on.
    """

    def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
        """
        Load every corpus and keep the selected part of each.

        Args:
            corpi: iterable of ``(dataset_name, config)`` pairs passed to
                ``datasets.load_dataset``. A config of ``''`` or ``'None'`` is
                treated as "no config" (``None``).
            cache_path: forwarded to ``datasets.load_dataset`` as ``cache_dir``.
            key_maps: optional dict mapping output key -> key in the underlying
                sample. When given, samples are returned containing only the
                mapped keys.
            dataset_spec_key: key (e.g. split name) used to index each
                ``load_dataset`` result.
        """
        self.hfd = []
        for dataset_name, config in corpi:
            # Accept the strings '' and 'None' as "no config" (presumably
            # because configs may be supplied as plain strings).
            if config == '' or config == 'None':
                config = None
            self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
        self.key_maps = key_maps

    def _remap(self, val):
        """Apply ``key_maps`` to one sample (identity when no map is set)."""
        if self.key_maps is None:
            return val
        return {k: val[v] for k, v in self.key_maps.items()}

    def __getitem__(self, item):
        """
        Return sample ``item`` of the concatenated datasets.

        Negative indices count from the end of the whole concatenation, as
        with a list. Previously they were resolved against the first
        dataset only.

        Raises:
            IndexError: if ``item`` is out of range.
        """
        total = len(self)
        index = item + total if item < 0 else item
        if not 0 <= index < total:
            raise IndexError(f'HfDataset index {item} out of range (length {total})')
        # Walk the datasets, turning the global index into a local one.
        for dataset in self.hfd:
            size = len(dataset)
            if index < size:
                return self._remap(dataset[index])
            index -= size
        # Only reachable if a wrapped dataset's length changed concurrently.
        raise IndexError(f'HfDataset index {item} out of range (length {total})')

    def __len__(self):
        """Total number of samples across all wrapped datasets."""
        return sum(len(h) for h in self.hfd)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: load the 'wikipedia' (20200501.en) and 'bookcorpus' train data
    # as one concatenated dataset and print a single sample.
    # NOTE: the cache directory below is machine-specific.
    demo_corpora = [['wikipedia', '20200501.en'], ['bookcorpus', '']]
    demo_dataset = HfDataset(
        demo_corpora,
        cache_path='Z:\\huggingface_datasets\\cache',
        dataset_spec_key='train',
    )
    sample = demo_dataset[5]
    print(sample)
|