diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 555c6e56..9f92f274 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -146,39 +146,6 @@ class PersonalizedBase(Dataset): return entry -class GroupedBatchSampler(Sampler): - def __init__(self, data_source: PersonalizedBase, batch_size: int): - super().__init__(data_source) - - n = len(data_source) - self.groups = data_source.groups - self.len = n_batch = n // batch_size - expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] - self.base = [int(e) // batch_size for e in expected] - self.n_rand_batches = nrb = n_batch - sum(self.base) - self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] - self.batch_size = batch_size - - def __len__(self): - return self.len - - def __iter__(self): - b = self.batch_size - - for g in self.groups: - random.shuffle(g) - - batches = [] - for g in self.groups: - batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) - for _ in range(self.n_rand_batches): - rand_group = random.choices(self.groups, self.probs)[0] - batches.append(random.choices(rand_group, k=b)) - - random.shuffle(batches) - - yield from batches - def greedy_pack(tails, b): '''inefficient suboptimal packing of remainders from each bucket''' by_len = defaultdict(list)