Variable sized batches

This commit is contained in:
dan 2023-01-20 14:42:20 +08:00
parent b165e341e7
commit 90bc57da51
3 changed files with 96 additions and 30 deletions

View File

@ -582,7 +582,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
@ -596,6 +596,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
clip_grad_sched.step(hypernetwork.step)
with devices.autocast():
for batch in superbatch:
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
@ -603,7 +604,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
del x
del c

View File

@ -6,7 +6,6 @@ from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices
import random
import tqdm
@ -167,23 +166,85 @@ class GroupedBatchSampler(Sampler):
b = self.batch_size
for g in self.groups:
shuffle(g)
random.shuffle(g)
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))
rand_group = random.choices(self.groups, self.probs)[0]
batches.append(random.choices(rand_group, k=b))
shuffle(batches)
random.shuffle(batches)
yield from batches
def greedy_pack(tails, b):
'''inefficient suboptimal packing of remainders from each bucket'''
by_len = defaultdict(list)
for t in tails:
by_len[len(t)].append(t)
n = sum(len(t) for t in tails) // b
superbatches = []
for _ in range(n):
to_pick = b
superbatch = []
while to_pick:
for k, v in sorted(by_len.items(), reverse=True):
# try pick longest
if k <= to_pick:
to_pick -= k
superbatch.append(v.pop())
if not v:
del(by_len[k])
break
else:
# can't find any so split a group
maxlen = max(by_len)
tail = by_len[maxlen].pop()
if not by_len[maxlen]:
del by_len[maxlen]
superbatch.append(tail[:to_pick])
by_len[len(tail[to_pick:])].append(tail[to_pick:])
to_pick = 0
superbatches.append(superbatch)
return superbatches
class VariableBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
self.n = len(data_source)
self.groups = data_source.groups
self.batch_size = batch_size
def __iter__(self):
b = self.batch_size
dropped = set(random.sample(range(self.n), self.n % b))
groups = [[x for x in g if x not in dropped] for g in self.groups]
for g in groups:
random.shuffle(g)
superbatches = []
for g in groups:
superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
random.shuffle(tails)
superbatches.extend(greedy_pack(tails, b))
random.shuffle(superbatches)
yield from [batch for superbatch in superbatches for batch in superbatch]
def group_batches(batches, batch_size):
m, superbatch = 0, []
for batch in batches:
superbatch.append(batch)
m += len(batch)
assert m <= batch_size
if m == batch_size:
yield superbatch
m, superbatch = 0, []
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
@ -198,6 +259,9 @@ class BatchLoader:
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
def __len__(self):
return len(self.cond)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self

View File

@ -459,7 +459,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
@ -473,6 +473,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
clip_grad_sched.step(embedding.step)
with devices.autocast():
for batch in superbatch:
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
@ -484,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
else:
cond = c
loss = shared.sd_model(x, cond)[0] / gradient_step
loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
del x
_loss_step += loss.item()