# Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] import os import gzip import numpy as np from random import Random from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union import collections from infinibatch import iterators def apply_to_sample(f, sample): if hasattr(sample, "__len__") and len(sample) == 0: return {} def _apply(x): if isinstance(x, np.ndarray): return f(x) elif isinstance(x, collections.OrderedDict): # OrderedDict has attributes that needs to be preserved od = collections.OrderedDict((key, _apply(value)) for key, value in x.items()) od.__dict__ = x.__dict__ return od elif isinstance(x, dict): return {key: _apply(value) for key, value in x.items()} elif isinstance(x, list): return [_apply(x) for x in x] elif isinstance(x, tuple): return tuple(_apply(x) for x in x) elif isinstance(x, set): return {_apply(x) for x in x} else: return x return _apply(sample) class NativeCheckpointableIterator(iterators.CheckpointableIterator): def __init__(self, iterable: Iterable): self._input_iterable = iterable self.setstate(None) def getstate(self) -> Dict: return {'num_items_yielded': self._num_items_yielded} def setstate(self, checkpoint: Optional[Dict]): self._iterator = iter(self._input_iterable) self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0 def __next__(self): item = next(self._iterator) self._num_items_yielded += 1 return item def close(self): pass class WeightIterator(object): def __init__(self, weights, seed): self.weights = weights self.seed = seed self.control_index = list(range(len(weights))) self.setstate(None) def __iter__(self): return self def getstate(self): return {"random_state": self._random_state} def setstate(self, checkpoint): self._random_state = checkpoint["random_state"] if checkpoint else None self._random = None # this will trigger the lazy initialization in self.__next__ def __next__(self): if self._random is None: self._random = Random(self.seed) if self._random_state is not None: self._random.setstate(self._random_state) idx = self._random.choices(self.control_index, self.weights)[0] self._random_state = self._random.getstate() return idx def close(self): pass