82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
|
import os
|
||
|
import gzip
|
||
|
import numpy as np
|
||
|
from random import Random
|
||
|
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
|
||
|
import collections
|
||
|
from infinibatch import iterators
|
||
|
|
||
|
def apply_to_sample(f, sample):
    """Apply ``f`` to every ``numpy.ndarray`` nested inside ``sample``.

    The structure of ``sample`` is walked recursively through dicts
    (including ``OrderedDict``), lists, tuples and sets. Each ndarray that
    is encountered is replaced by ``f(array)``; every other leaf value is
    returned unchanged. Containers are rebuilt, not modified in place.

    Args:
        f: Callable invoked with each ndarray found in ``sample``.
        sample: An ndarray, a (possibly nested) container, or any other
            object.

    Returns:
        ``{}`` when ``sample`` reports a length of zero (whatever its type),
        otherwise a structure mirroring ``sample`` with ``f`` applied to
        every ndarray.
    """
    # ``len()`` raises TypeError both for objects without ``__len__`` and for
    # 0-d ndarrays ("len() of unsized object"). A ``hasattr`` check alone does
    # not catch the latter, so treat any TypeError as "not empty".
    try:
        is_empty = len(sample) == 0
    except TypeError:
        is_empty = False
    if is_empty:
        return {}

    def _apply(x):
        if isinstance(x, np.ndarray):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            # OrderedDict instances can carry extra attributes that need to
            # be preserved; the attribute dict is shared with the source
            # object, not copied.
            od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        elif isinstance(x, tuple):
            # Tuple subclasses come back as a plain tuple.
            return tuple(_apply(item) for item in x)
        elif isinstance(x, set):
            return {_apply(item) for item in x}
        else:
            return x

    return _apply(sample)
|
||
|
|
||
|
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
    """Checkpointable wrapper around a plain Python iterable.

    The checkpoint is simply the number of items handed out so far. Restoring
    a checkpoint restarts iteration over the wrapped iterable and advances
    past that many items, so the iterable must be re-iterable via ``iter()``
    for a restore to start from the beginning again.
    """

    def __init__(self, iterable: Iterable):
        """Remember ``iterable`` and start from a fresh (un-checkpointed) state."""
        self._input_iterable = iterable
        self.setstate(None)

    def getstate(self) -> Dict:
        """Return the checkpoint: how many items have been yielded so far."""
        return dict(num_items_yielded=self._num_items_yielded)

    def setstate(self, checkpoint: Optional[Dict]):
        """Restart iteration, optionally fast-forwarding to ``checkpoint``.

        Args:
            checkpoint: A dict as produced by ``getstate``, or ``None`` to
                start from the beginning.
        """
        self._iterator = iter(self._input_iterable)
        if checkpoint is None:
            self._num_items_yielded = 0
        else:
            # NOTE(review): presumably ``_advance_iterator`` consumes that many
            # items and returns the count consumed; confirm against infinibatch.
            already_yielded = checkpoint['num_items_yielded']
            self._num_items_yielded = iterators._advance_iterator(self._iterator, already_yielded)

    def __next__(self):
        """Return the next item and bump the yielded-item counter."""
        # If the underlying iterator is exhausted, StopIteration propagates
        # from here before the counter is touched.
        upcoming = next(self._iterator)
        self._num_items_yielded += 1
        return upcoming

    def close(self):
        """No-op: this iterator holds nothing that needs releasing."""
        pass
|
||
|
|
||
|
|
||
|
class WeightIterator(object):
    """Endless iterator that yields weighted-random indices into ``weights``.

    Each ``next()`` draws one index from ``range(len(weights))`` using
    ``random.Random.choices`` with ``weights`` as the relative weights. The
    generator state is captured after every draw so that iteration can be
    checkpointed and resumed deterministically.
    """

    def __init__(self, weights, seed):
        """Store the sampling weights and seed, then reset to the initial state.

        Args:
            weights: Sized sequence of relative weights, one per index.
            seed: Seed handed to ``random.Random`` when the generator is
                (re)created.
        """
        self.weights = weights
        self.seed = seed
        # Population sampled from: the positions 0 .. len(weights) - 1.
        self.control_index = [*range(len(weights))]
        self.setstate(None)

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def getstate(self):
        """Return a checkpoint holding the generator state after the last draw.

        The stored state is ``None`` if nothing has been drawn since a reset
        to the initial state.
        """
        return {"random_state": self._random_state}

    def setstate(self, checkpoint):
        """Restore from ``checkpoint``, or reset when it is falsy.

        Args:
            checkpoint: A dict as produced by ``getstate``; ``None`` (or any
                falsy value) resets to the seeded initial state.
        """
        if checkpoint:
            self._random_state = checkpoint["random_state"]
        else:
            self._random_state = None
        # Drop the generator; it is rebuilt lazily on the next ``__next__``.
        self._random = None

    def __next__(self):
        """Draw and return the next weighted-random index."""
        rng = self._random
        if rng is None:
            # Lazy (re)creation: seed first, then overwrite with the
            # checkpointed state if one was supplied.
            self._random = rng = Random(self.seed)
            if self._random_state is not None:
                rng.setstate(self._random_state)
        (picked,) = rng.choices(self.control_index, self.weights)
        # Snapshot the state so ``getstate`` reflects this draw.
        self._random_state = rng.getstate()
        return picked

    def close(self):
        """No-op: nothing to release."""
        pass
|