torchscale/examples/fairseq/tasks/data/utils.py

95 lines
2.7 KiB
Python
Raw Normal View History

2022-11-23 16:36:55 +00:00
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
2022-11-26 17:01:02 +00:00
import collections
2022-11-23 16:21:58 +00:00
from random import Random
2022-11-26 16:10:15 +00:00
from typing import Dict, Iterable, Optional
2022-11-26 17:01:02 +00:00
import numpy as np
2022-11-23 16:21:58 +00:00
from infinibatch import iterators
2022-11-26 16:10:15 +00:00
2022-11-23 16:21:58 +00:00
def apply_to_sample(f, sample):
    """Apply ``f`` to every ``np.ndarray`` nested inside ``sample``.

    ``OrderedDict``, ``dict``, ``list``, ``tuple`` and ``set`` values are
    traversed recursively and rebuilt as ``OrderedDict``, ``dict``, ``list``,
    ``tuple`` and ``set`` respectively. Any other non-array value is returned
    unchanged.

    Args:
        f: Callable invoked on each ``np.ndarray`` found; its return value
            takes the array's place in the result.
        sample: Arbitrarily nested structure to traverse.

    Returns:
        ``{}`` when ``sample`` is a sized object of length zero; otherwise a
        structure mirroring ``sample`` with each array replaced by
        ``f(array)``.
    """
    # A zero-dimensional array exposes ``__len__`` but raises TypeError when
    # it is called, so probe the length EAFP-style instead of via ``hasattr``.
    try:
        is_empty = len(sample) == 0
    except TypeError:
        is_empty = False
    if is_empty:
        return {}

    def _apply(x):
        if isinstance(x, np.ndarray):
            return f(x)
        if isinstance(x, collections.OrderedDict):
            # Must be tested before the plain ``dict`` branch, because an
            # OrderedDict is also a dict.
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            # An OrderedDict may carry instance attributes that need to be
            # preserved; the attribute dict is shared with the input, not
            # copied.
            od.__dict__ = x.__dict__
            return od
        if isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_apply(item) for item in x]
        if isinstance(x, tuple):
            return tuple(_apply(item) for item in x)
        if isinstance(x, set):
            return {_apply(item) for item in x}
        return x

    return _apply(sample)
2022-11-26 16:10:15 +00:00
2022-11-23 16:21:58 +00:00
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
    """Checkpointable wrapper around a plain Python iterable.

    The checkpoint is the number of items handed out so far. Restoring one
    re-creates the iterator from the stored iterable and advances it by that
    count.
    """

    def __init__(self, iterable: Iterable):
        """Store ``iterable`` and start iterating it from the beginning."""
        self._input_iterable = iterable
        self.setstate(None)

    def getstate(self) -> Dict:
        """Return a checkpoint dict holding the count of items yielded."""
        yielded = self._num_items_yielded
        return {"num_items_yielded": yielded}

    def setstate(self, checkpoint: Optional[Dict]):
        """Restart iteration, optionally fast-forwarding to ``checkpoint``.

        Args:
            checkpoint: A dict as produced by ``getstate``, or ``None`` to
                start over from the first item.
        """
        # Always begin from a fresh iterator over the stored iterable.
        self._iterator = iter(self._input_iterable)
        if checkpoint is None:
            self._num_items_yielded = 0
        else:
            # NOTE(review): presumably ``_advance_iterator`` skips the given
            # number of items and returns how many were skipped -- confirm
            # against infinibatch.
            target = checkpoint["num_items_yielded"]
            self._num_items_yielded = iterators._advance_iterator(
                self._iterator, target
            )

    def __next__(self):
        """Return the next item and bump the yielded-items counter."""
        # The counter is only incremented when ``next`` succeeds; an exhausted
        # iterator propagates StopIteration before the increment.
        nxt = next(self._iterator)
        self._num_items_yielded += 1
        return nxt

    def close(self):
        """Do nothing; this iterator holds no resources to release."""
        pass
class WeightIterator:
    """Endless iterator yielding indices sampled in proportion to ``weights``.

    Each ``next`` call draws one index from ``range(len(weights))`` using a
    ``random.Random`` seeded with ``seed``. The generator's internal state is
    recorded after every draw so it can be checkpointed and restored.
    """

    def __init__(self, weights, seed):
        """Store the sampling weights and seed, then reset to a fresh state."""
        self.weights = weights
        self.seed = seed
        # Population to draw from: one index per weight.
        self.control_index = list(range(len(weights)))
        self.setstate(None)

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def getstate(self):
        """Return a checkpoint dict holding the last recorded RNG state."""
        return {"random_state": self._random_state}

    def setstate(self, checkpoint):
        """Restore from ``checkpoint``; a falsy checkpoint resets the state.

        Args:
            checkpoint: A dict as produced by ``getstate``, or a falsy value
                to start over from the seed.
        """
        if checkpoint:
            self._random_state = checkpoint["random_state"]
        else:
            self._random_state = None
        # Dropping the generator makes ``__next__`` rebuild it lazily.
        self._random = None

    def __next__(self):
        """Draw and return one weighted-random index."""
        if self._random is None:
            # Lazy construction: seed first, then overlay any restored state.
            rng = Random(self.seed)
            if self._random_state is not None:
                rng.setstate(self._random_state)
            self._random = rng
        (picked,) = self._random.choices(self.control_index, self.weights)
        # Record the post-draw state so ``getstate`` resumes after this item.
        self._random_state = self._random.getstate()
        return picked

    def close(self):
        """Do nothing; this iterator holds no resources to release."""
        pass