from torch.utils.data import Dataset
import datasets


class HfDataset(Dataset):
    """
    Simple wrapper for a HuggingFace dataset that can re-map keys if desired.

    Several corpora can be given; they are concatenated end to end, so index 0
    is the first row of the first corpus and indices continue into the next
    corpus once the previous one is exhausted.
    """

    def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
        """
        :param corpi: iterable of ``(dataset_name, config)`` pairs, each passed
            to ``datasets.load_dataset``. A config of ``''`` or ``'None'`` is
            treated as "no config".
        :param cache_path: forwarded to ``load_dataset`` as ``cache_dir``.
        :param key_maps: optional mapping ``{output_key: source_key}``. When
            given, each returned row is a new dict containing only these keys.
        :param dataset_spec_key: key (split name) selected from each object
            returned by ``load_dataset``.
        """
        self.hfd = []
        for dataset_name, config in corpi:
            # Configs often arrive as strings (e.g. from a config file), so
            # accept the textual spellings of "no config" as well.
            if config in ('', 'None'):
                config = None
            loaded = datasets.load_dataset(dataset_name, config, cache_dir=cache_path)
            self.hfd.append(loaded[dataset_spec_key])
        self.key_maps = key_maps

    def __getitem__(self, item):
        """
        Return row ``item`` of the concatenated corpora.

        Negative indices count from the end of the whole concatenation, as for
        a regular Python sequence.

        :raises IndexError: if ``item`` is outside ``[-len(self), len(self))``.
        """
        requested = item
        if item < 0:
            # Resolve against the total length. Without this, any negative
            # index would be served from the first corpus alone.
            item += len(self)
            if item < 0:
                raise IndexError(f'HfDataset index {requested} out of range')
        for dataset in self.hfd:
            size = len(dataset)
            if item < size:
                val = dataset[item]
                if self.key_maps is None:
                    return val
                return {k: val[v] for k, v in self.key_maps.items()}
            # Skip past this corpus and continue into the next one.
            item -= size
        raise IndexError(f'HfDataset index {requested} out of range')

    def __len__(self):
        """Total number of rows across all wrapped corpora."""
        return sum(len(h) for h in self.hfd)


if __name__ == '__main__':
    d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
    print(d[5])