48 lines
806 B
Python
48 lines
806 B
Python
|
"""
|
||
|
A sampler that balances data by key_fns.
|
||
|
|
||
|
MIT License
|
||
|
|
||
|
Copyright (c) 2023 Zhe Niu
|
||
|
|
||
|
niuzhe.nz@outlook.com
|
||
|
"""
|
||
|
|
||
|
import random
|
||
|
|
||
|
|
||
|
class Sampler:
|
||
|
def __init__(self, l, key_fns):
|
||
|
self.tree = self._build(l, key_fns)
|
||
|
|
||
|
def _build(self, l, key_fns) -> dict[dict, list]:
|
||
|
if not key_fns:
|
||
|
return l
|
||
|
|
||
|
tree = {}
|
||
|
|
||
|
key_fn, *key_fns = key_fns
|
||
|
|
||
|
for x in l:
|
||
|
k = key_fn(x)
|
||
|
|
||
|
if k in tree:
|
||
|
tree[k].append(x)
|
||
|
else:
|
||
|
tree[k] = [x]
|
||
|
|
||
|
for k in tree:
|
||
|
tree[k] = self._build(tree[k], key_fns)
|
||
|
|
||
|
return tree
|
||
|
|
||
|
def _sample(self, tree: dict | list):
|
||
|
if isinstance(tree, list):
|
||
|
ret = random.choice(tree)
|
||
|
else:
|
||
|
key = random.choice([*tree.keys()])
|
||
|
ret = self._sample(tree[key])
|
||
|
return ret
|
||
|
|
||
|
def sample(self):
|
||
|
return self._sample(self.tree)
|