torchscale/examples/fairseq/tasks/data/utils.py
2022-11-26 08:10:15 -08:00

89 lines
2.7 KiB
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import numpy as np
from random import Random
from typing import Dict, Iterable, Optional
import collections
from infinibatch import iterators
def apply_to_sample(f, sample):
if hasattr(sample, "__len__") and len(sample) == 0:
return {}
def _apply(x):
if isinstance(x, np.ndarray):
return f(x)
elif isinstance(x, collections.OrderedDict):
# OrderedDict has attributes that needs to be preserved
od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
od.__dict__ = x.__dict__
return od
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list):
return [_apply(x) for x in x]
elif isinstance(x, tuple):
return tuple(_apply(x) for x in x)
elif isinstance(x, set):
return {_apply(x) for x in x}
else:
return x
return _apply(sample)
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def __init__(self, iterable: Iterable):
self._input_iterable = iterable
self.setstate(None)
def getstate(self) -> Dict:
return {'num_items_yielded': self._num_items_yielded}
def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
self._num_items_yielded = iterators._advance_iterator(
self._iterator,
checkpoint['num_items_yielded']
) if checkpoint is not None else 0
def __next__(self):
item = next(self._iterator)
self._num_items_yielded += 1
return item
def close(self):
pass
class WeightIterator(object):
def __init__(self, weights, seed):
self.weights = weights
self.seed = seed
self.control_index = list(range(len(weights)))
self.setstate(None)
def __iter__(self):
return self
def getstate(self):
return {"random_state": self._random_state}
def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
self._random = None # this will trigger the lazy initialization in self.__next__
def __next__(self):
if self._random is None:
self._random = Random(self.seed)
if self._random_state is not None:
self._random.setstate(self._random_state)
idx = self._random.choices(self.control_index, self.weights)[0]
self._random_state = self._random.getstate()
return idx
def close(self):
pass