Remove unused code

This commit is contained in:
dan 2023-01-23 23:18:07 +08:00
parent f4f070a548
commit a7c92bd5ec

View File

@ -146,39 +146,6 @@ class PersonalizedBase(Dataset):
return entry
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
super().__init__(data_source)
n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
self.base = [int(e) // batch_size for e in expected]
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
def __len__(self):
return self.len
def __iter__(self):
b = self.batch_size
for g in self.groups:
random.shuffle(g)
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = random.choices(self.groups, self.probs)[0]
batches.append(random.choices(rand_group, k=b))
random.shuffle(batches)
yield from batches
def greedy_pack(tails, b):
'''inefficient suboptimal packing of remainders from each bucket'''
by_len = defaultdict(list)