diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 825a93b2..f44655f5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -625,7 +625,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
break
if shared.state.interrupted:
break
- for j, batch in enumerate(dl):
+ for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
@@ -638,19 +638,20 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if clip_grad:
clip_grad_sched.step(hypernetwork.step)
- with devices.autocast():
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
- if tag_drop_out != 0 or shuffle_tags:
- shared.sd_model.cond_stage_model.to(devices.device)
- c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
- shared.sd_model.cond_stage_model.to(devices.cpu)
- else:
- c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
- loss = shared.sd_model(x, c)[0] / gradient_step
- del x
- del c
+ def get_loss(batch):
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if tag_drop_out != 0 or shuffle_tags:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ else:
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
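+                # weight this sub-batch by its share of the full batch so the summed superbatch loss averages over batch_size samples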
+ return shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
- _loss_step += loss.item()
+ loss = sum(get_loss(batch) for batch in superbatch)
+ _loss_step += loss.item()
scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps
@@ -727,7 +727,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
p.width = preview_width
p.height = preview_height
else:
- p.prompt = batch.cond_text[0]
+ p.prompt = superbatch[0].cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
@@ -759,7 +759,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
Loss: {loss_step:.7f}
Step: {steps_done}
-Last prompt: {html.escape(batch.cond_text[0])}
+Last prompt: {html.escape(superbatch[0].cond_text[0])}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d31963d4..9f92f274 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -6,7 +6,6 @@ from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
-from random import shuffle, choices
import random
import tqdm
@@ -147,43 +146,76 @@ class PersonalizedBase(Dataset):
return entry
-class GroupedBatchSampler(Sampler):
+def greedy_pack(tails, b):
+    '''Greedy, suboptimal packing of the per-bucket remainder batches into superbatches of b items each.'''
+ by_len = defaultdict(list)
+ for t in tails:
+ by_len[len(t)].append(t)
+ n = sum(len(t) for t in tails) // b
+ superbatches = []
+ for _ in range(n):
+ to_pick = b
+ superbatch = []
+ while to_pick:
+ for k, v in sorted(by_len.items(), reverse=True):
+                # try to pick the longest remainder that still fits
+ if k <= to_pick:
+ to_pick -= k
+ superbatch.append(v.pop())
+ if not v:
+                        del by_len[k]
+ break
+ else:
+                # nothing fits whole, so split the longest remainder
+ maxlen = max(by_len)
+ tail = by_len[maxlen].pop()
+ if not by_len[maxlen]:
+ del by_len[maxlen]
+ superbatch.append(tail[:to_pick])
+ by_len[len(tail[to_pick:])].append(tail[to_pick:])
+ to_pick = 0
+ superbatches.append(superbatch)
+ return superbatches
+
+class VariableBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
- super().__init__(data_source)
-
- n = len(data_source)
+ self.n = len(data_source)
self.groups = data_source.groups
- self.len = n_batch = n // batch_size
- expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
- self.base = [int(e) // batch_size for e in expected]
- self.n_rand_batches = nrb = n_batch - sum(self.base)
- self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
-
- def __len__(self):
- return self.len
-
+
def __iter__(self):
b = self.batch_size
+ dropped = set(random.sample(range(self.n), self.n % b))
+ groups = [[x for x in g if x not in dropped] for g in self.groups]
+ for g in groups:
+ random.shuffle(g)
+ superbatches = []
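+        # every full-size slice of a bucket becomes a superbatch holding a single batch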
+ for g in groups:
+ superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
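+        # pack the per-bucket leftovers together, so a superbatch may hold several smaller same-bucket batches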
+ tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
+ random.shuffle(tails)
+ superbatches.extend(greedy_pack(tails, b))
+ random.shuffle(superbatches)
+ yield from [batch for superbatch in superbatches for batch in superbatch]
- for g in self.groups:
- shuffle(g)
-
- batches = []
- for g in self.groups:
- batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
- for _ in range(self.n_rand_batches):
- rand_group = choices(self.groups, self.probs)[0]
- batches.append(choices(rand_group, k=b))
-
- shuffle(batches)
-
- yield from batches
+def group_batches(batches, batch_size):
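+    '''Regroup the DataLoader's batch stream into superbatches whose batch lengths sum to batch_size.'''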
+ m, superbatch = 0, []
+ for batch in batches:
+ superbatch.append(batch)
+ m += len(batch)
+ assert m <= batch_size
+ if m == batch_size:
+ yield superbatch
+ m, superbatch = 0, []
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
@@ -197,6 +225,9 @@ class BatchLoader:
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
+
+ def __len__(self):
+ return len(self.cond)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index a1a406c2..a1de318a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -465,7 +465,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
break
if shared.state.interrupted:
break
- for j, batch in enumerate(dl):
+ for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
@@ -477,23 +477,26 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if clip_grad:
clip_grad_sched.step(embedding.step)
-
- with devices.autocast():
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
- c = shared.sd_model.cond_stage_model(batch.cond_text)
- if is_training_inpainting_model:
- if img_c is None:
- img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+            def get_loss(batch):
+                nonlocal img_c  # img_c is assigned below; without nonlocal the closure raises UnboundLocalError
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ c = shared.sd_model.cond_stage_model(batch.cond_text)
- cond = {"c_concat": [img_c], "c_crossattn": [c]}
- else:
- cond = c
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
- loss = shared.sd_model(x, cond)[0] / gradient_step
- del x
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
- _loss_step += loss.item()
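+                    # weight this sub-batch by its share of the full batch so the summed superbatch loss averages over batch_size samples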
+ return shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
+
+ loss = sum(get_loss(batch) for batch in superbatch)
+ _loss_step += loss.item()
scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps
@@ -553,7 +554,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
p.width = preview_width
p.height = preview_height
else:
- p.prompt = batch.cond_text[0]
+ p.prompt = superbatch[0].cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
@@ -610,7 +611,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
Loss: {loss_step:.7f}
Step: {steps_done}
-Last prompt: {html.escape(batch.cond_text[0])}
+Last prompt: {html.escape(superbatch[0].cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
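
Illustrative sketch (not part of the patch): a toy run of the new packing helpers, assuming the patched modules.textual_inversion.dataset is importable; the bucket entries below are made-up strings standing in for dataset indices.

from modules.textual_inversion.dataset import greedy_pack, group_batches

batch_size = 4
# leftover "tails" of three aspect-ratio buckets after the full-size slices were taken
tails = [["a1", "a2", "a3"], ["b1", "b2"], ["c1", "c2", "c3"]]

superbatches = greedy_pack(tails, batch_size)
for superbatch in superbatches:
    # every inner batch still comes from a single bucket (one image size),
    # and the batch lengths within a superbatch sum to batch_size
    assert sum(len(batch) for batch in superbatch) == batch_size

# group_batches reassembles the flat per-batch stream a DataLoader yields
# back into superbatches for the training loop
flat_stream = [batch for superbatch in superbatches for batch in superbatch]
assert list(group_batches(flat_stream, batch_size)) == superbatches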