Remove unused code
This commit is contained in:
parent
f4f070a548
commit
a7c92bd5ec
|
@ -146,39 +146,6 @@ class PersonalizedBase(Dataset):
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|
||||||
class GroupedBatchSampler(Sampler):
|
|
||||||
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
|
||||||
super().__init__(data_source)
|
|
||||||
|
|
||||||
n = len(data_source)
|
|
||||||
self.groups = data_source.groups
|
|
||||||
self.len = n_batch = n // batch_size
|
|
||||||
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
|
||||||
self.base = [int(e) // batch_size for e in expected]
|
|
||||||
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
|
||||||
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
|
||||||
self.batch_size = batch_size
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.len
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
b = self.batch_size
|
|
||||||
|
|
||||||
for g in self.groups:
|
|
||||||
random.shuffle(g)
|
|
||||||
|
|
||||||
batches = []
|
|
||||||
for g in self.groups:
|
|
||||||
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
|
||||||
for _ in range(self.n_rand_batches):
|
|
||||||
rand_group = random.choices(self.groups, self.probs)[0]
|
|
||||||
batches.append(random.choices(rand_group, k=b))
|
|
||||||
|
|
||||||
random.shuffle(batches)
|
|
||||||
|
|
||||||
yield from batches
|
|
||||||
|
|
||||||
def greedy_pack(tails, b):
|
def greedy_pack(tails, b):
|
||||||
'''inefficient suboptimal packing of remainders from each bucket'''
|
'''inefficient suboptimal packing of remainders from each bucket'''
|
||||||
by_len = defaultdict(list)
|
by_len = defaultdict(list)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user