vall-e/vall_e/utils/sampler.py

48 lines
806 B
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
"""
A sampler that balances data by key_fns.
MIT License
Copyright (c) 2023 Zhe Niu
niuzhe.nz@outlook.com
"""
import random
class Sampler:
def __init__(self, l, key_fns):
self.tree = self._build(l, key_fns)
def _build(self, l, key_fns) -> dict[dict, list]:
if not key_fns:
return l
tree = {}
key_fn, *key_fns = key_fns
for x in l:
k = key_fn(x)
if k in tree:
tree[k].append(x)
else:
tree[k] = [x]
for k in tree:
tree[k] = self._build(tree[k], key_fns)
return tree
def _sample(self, tree: dict | list):
if isinstance(tree, list):
ret = random.choice(tree)
else:
key = random.choice([*tree.keys()])
ret = self._sample(tree[key])
return ret
def sample(self):
return self._sample(self.tree)