37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
|
from torch.utils.data import Dataset
|
||
|
import datasets
|
||
|
|
||
|
|
||
|
class HfDataset(Dataset):
    """
    Simple wrapper for one or more HuggingFace datasets that can re-map keys if desired.

    Every requested corpus is loaded up front and the results are exposed as a
    single flat, index-addressable sequence: indices run through the first
    loaded dataset, then continue into the second, and so on, in the order the
    corpora were given.
    """

    def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
        """
        Load every requested corpus.

        Args:
            corpi: Iterable of ``(dataset_name, config)`` pairs. An empty-string
                config is converted to ``None`` before being handed to
                ``datasets.load_dataset``.
            cache_path: Forwarded to ``datasets.load_dataset`` as ``cache_dir``.
            key_maps: Optional mapping ``{output_key: source_key}``. When given,
                each returned element is a new dict holding only the mapped
                keys. When ``None``, elements are returned unchanged.
            dataset_spec_key: Key used to select an entry from the object
                returned by ``datasets.load_dataset`` (defaults to ``'train'``).
        """
        super().__init__()
        self.hfd = []
        for dataset_name, config in corpi:
            if config == '':
                # An empty string is the caller's way of saying "no config".
                config = None
            loaded = datasets.load_dataset(dataset_name, config, cache_dir=cache_path)
            self.hfd.append(loaded[dataset_spec_key])
        self.key_maps = key_maps

    def __getitem__(self, item):
        """
        Return the element at position ``item`` of the concatenated datasets.

        Negative indices count from the end of the whole concatenation, as
        with built-in sequences.

        Raises:
            IndexError: If ``item`` falls outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Normalise once against the combined length so a negative index is
        # not silently resolved against the first dataset only.
        offset = item + total if item < 0 else item
        if offset < 0 or offset >= total:
            raise IndexError(f'index {item} is out of range for dataset of length {total}')

        for dataset in self.hfd:
            size = len(dataset)
            if offset < size:
                val = dataset[offset]
                if self.key_maps is None:
                    return val
                return {k: val[v] for k, v in self.key_maps.items()}
            # Not in this dataset: shift the offset into the next one.
            offset -= size

        # Only reachable if a dataset's length changed during the lookup.
        raise IndexError(f'index {item} is out of range for dataset of length {total}')

    def __len__(self):
        """Return the combined number of elements across all loaded datasets."""
        return sum(len(h) for h in self.hfd)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Manual smoke test: load the 'wikipedia' and 'bookcorpus' corpora through
    # the wrapper and print a single element of the combined dataset.
    corpora = [['wikipedia', '20200501.en'], ['bookcorpus', '']]
    combined = HfDataset(
        corpora,
        dataset_spec_key='train',
        cache_path='Z:\\huggingface_datasets\\cache',
    )
    print(combined[5])
|