Remove unused code
This commit is contained in:
parent
f4f070a548
commit
a7c92bd5ec
|
@ -146,39 +146,6 @@ class PersonalizedBase(Dataset):
|
|||
return entry
|
||||
|
||||
|
||||
class GroupedBatchSampler(Sampler):
|
||||
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||
super().__init__(data_source)
|
||||
|
||||
n = len(data_source)
|
||||
self.groups = data_source.groups
|
||||
self.len = n_batch = n // batch_size
|
||||
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||
self.base = [int(e) // batch_size for e in expected]
|
||||
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
def __iter__(self):
|
||||
b = self.batch_size
|
||||
|
||||
for g in self.groups:
|
||||
random.shuffle(g)
|
||||
|
||||
batches = []
|
||||
for g in self.groups:
|
||||
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||
for _ in range(self.n_rand_batches):
|
||||
rand_group = random.choices(self.groups, self.probs)[0]
|
||||
batches.append(random.choices(rand_group, k=b))
|
||||
|
||||
random.shuffle(batches)
|
||||
|
||||
yield from batches
|
||||
|
||||
def greedy_pack(tails, b):
|
||||
'''inefficient suboptimal packing of remainders from each bucket'''
|
||||
by_len = defaultdict(list)
|
||||
|
|
Loading…
Reference in New Issue
Block a user