diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74e78582..2526c4ae 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -582,7 +582,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -596,18 +596,19 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     clip_grad_sched.step(hypernetwork.step)
 
                 with devices.autocast():
-                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                    if tag_drop_out != 0 or shuffle_tags:
-                        shared.sd_model.cond_stage_model.to(devices.device)
-                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
-                        shared.sd_model.cond_stage_model.to(devices.cpu)
-                    else:
-                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0] / gradient_step
-                    del x
-                    del c
+                    for batch in superbatch:
+                        x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                        if tag_drop_out != 0 or shuffle_tags:
+                            shared.sd_model.cond_stage_model.to(devices.device)
+                            c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+                            shared.sd_model.cond_stage_model.to(devices.cpu)
+                        else:
+                            c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+                        loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
+                        del x
+                        del c
 
-                    _loss_step += loss.item()
+                        _loss_step += loss.item()
                 scaler.scale(loss).backward()
 
                 # go back until we reach gradient accumulation steps
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d31963d4..555c6e56 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -6,7 +6,6 @@ from PIL import Image
 from torch.utils.data import Dataset, DataLoader, Sampler
 from torchvision import transforms
 from collections import defaultdict
-from random import shuffle, choices
 
 import random
 import tqdm
@@ -167,23 +166,85 @@ class GroupedBatchSampler(Sampler):
         b = self.batch_size
 
         for g in self.groups:
-            shuffle(g)
+            random.shuffle(g)
 
         batches = []
         for g in self.groups:
             batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
         for _ in range(self.n_rand_batches):
-            rand_group = choices(self.groups, self.probs)[0]
-            batches.append(choices(rand_group, k=b))
+            rand_group = random.choices(self.groups, self.probs)[0]
+            batches.append(random.choices(rand_group, k=b))
 
-        shuffle(batches)
+        random.shuffle(batches)
 
         yield from batches
 
 
+def greedy_pack(tails, b):
+    '''inefficient suboptimal packing of remainders from each bucket'''
+    by_len = defaultdict(list)
+    for t in tails:
+        by_len[len(t)].append(t)
+    n = sum(len(t) for t in tails) // b
+    superbatches = []
+    for _ in range(n):
+        to_pick = b
+        superbatch = []
+        while to_pick:
+            for k, v in sorted(by_len.items(), reverse=True):
+                # try pick longest
+                if k <= to_pick:
+                    to_pick -= k
+                    superbatch.append(v.pop())
+                    if not v:
+                        del(by_len[k])
+                    break
+            else:
+                # can't find any so split a group
+                maxlen = max(by_len)
+                tail = by_len[maxlen].pop()
+                if not by_len[maxlen]:
+                    del by_len[maxlen]
+                superbatch.append(tail[:to_pick])
+                by_len[len(tail[to_pick:])].append(tail[to_pick:])
+                to_pick = 0
+        superbatches.append(superbatch)
+    return superbatches
+
+class VariableBatchSampler(Sampler):
+    def __init__(self, data_source: PersonalizedBase, batch_size: int):
+        self.n = len(data_source)
+        self.groups = data_source.groups
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        b = self.batch_size
+        dropped = set(random.sample(range(self.n), self.n % b))
+        groups = [[x for x in g if x not in dropped] for g in self.groups]
+        for g in groups:
+            random.shuffle(g)
+        superbatches = []
+        for g in groups:
+            superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
+        tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
+        random.shuffle(tails)
+        superbatches.extend(greedy_pack(tails, b))
+        random.shuffle(superbatches)
+        yield from [batch for superbatch in superbatches for batch in superbatch]
+
+def group_batches(batches, batch_size):
+    m, superbatch = 0, []
+    for batch in batches:
+        superbatch.append(batch)
+        m += len(batch)
+        assert m <= batch_size
+        if m == batch_size:
+            yield superbatch
+            m, superbatch = 0, []
+
 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
         if latent_sampling_method == "random":
             self.collate_fn = collate_wrapper_random
         else:
@@ -197,6 +258,9 @@ class BatchLoader:
         self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)
+
+    def __len__(self):
+        return len(self.cond)
 
     def pin_memory(self):
         self.latent_sample = self.latent_sample.pin_memory()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5a7be422..d60179d1 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -459,7 +459,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -473,21 +473,22 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     clip_grad_sched.step(embedding.step)
 
                 with devices.autocast():
-                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                    c = shared.sd_model.cond_stage_model(batch.cond_text)
+                    for batch in superbatch:
+                        x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                        c = shared.sd_model.cond_stage_model(batch.cond_text)
 
-                    if is_training_inpainting_model:
-                        if img_c is None:
-                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+                        if is_training_inpainting_model:
+                            if img_c is None:
+                                img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
 
-                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
-                    else:
-                        cond = c
+                            cond = {"c_concat": [img_c], "c_crossattn": [c]}
+                        else:
+                            cond = c
 
-                    loss = shared.sd_model(x, cond)[0] / gradient_step
-                    del x
+                        loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
+                        del x
 
-                    _loss_step += loss.item()
+                        _loss_step += loss.item()
                 scaler.scale(loss).backward()
 
                 # go back until we reach gradient accumulation steps
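
Not part of the patch: a minimal, self-contained Python sketch of the arithmetic behind the "* len(batch) / batch_size" factor used in both training loops above. The 3 + 3 + 2 sub-batch split and the "mean" helper are made up for illustration; the point is that weighting each sub-batch's mean loss by its share of batch_size makes the contributions of a variable-size superbatch sum to the same value as the mean over one uniform batch of batch_size samples.

# Illustrative sketch only (not from the patch): why each sub-batch's loss
# is scaled by len(batch) / batch_size when iterating over a superbatch.
import random

batch_size = 8
per_sample_losses = [random.uniform(0.0, 1.0) for _ in range(batch_size)]

# A hypothetical superbatch: the 8 samples arrive as sub-batches of 3, 3 and 2,
# e.g. drawn from three different aspect-ratio buckets.
sub_batches = [per_sample_losses[0:3], per_sample_losses[3:6], per_sample_losses[6:8]]

def mean(values):
    return sum(values) / len(values)

# Each sub-batch contributes its mean loss weighted by its share of the
# superbatch, mirroring the "... * len(batch) / batch_size" factor above.
weighted_sum = sum(mean(sub) * len(sub) / batch_size for sub in sub_batches)

# The weighted contributions sum to the plain mean over all batch_size samples.
assert abs(weighted_sum - mean(per_sample_losses)) < 1e-9
print(weighted_sum, mean(per_sample_losses))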