Variable sized batches

dan 2023-01-20 14:42:20 +08:00
parent b165e341e7
commit 90bc57da51
3 changed files with 96 additions and 30 deletions

View File

@@ -582,7 +582,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -596,18 +596,19 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     clip_grad_sched.step(hypernetwork.step)
 
                 with devices.autocast():
-                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                    if tag_drop_out != 0 or shuffle_tags:
-                        shared.sd_model.cond_stage_model.to(devices.device)
-                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
-                        shared.sd_model.cond_stage_model.to(devices.cpu)
-                    else:
-                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0] / gradient_step
-                    del x
-                    del c
+                    for batch in superbatch:
+                        x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                        if tag_drop_out != 0 or shuffle_tags:
+                            shared.sd_model.cond_stage_model.to(devices.device)
+                            c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+                            shared.sd_model.cond_stage_model.to(devices.cpu)
+                        else:
+                            c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+                        loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
+                        del x
+                        del c
 
-                    _loss_step += loss.item()
-                    scaler.scale(loss).backward()
+                        _loss_step += loss.item()
+                        scaler.scale(loss).backward()
 
                 # go back until we reach gradient accumulation steps
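The extra factor of len(batch) / batch_size on the loss keeps gradient magnitudes in line with the old fixed-size path: each sub-batch loss returned by the model is a mean over only len(batch) samples, so weighting it by its share of the superbatch makes the accumulated sum equal the mean over the full batch_size samples. A minimal sketch with made-up numbers (illustration only, not part of the commit):

    import torch

    batch_size = 6                             # target superbatch size
    sub_batch_losses = [                       # hypothetical per-sample losses,
        torch.tensor([0.9, 1.1, 1.0, 1.0]),    # one tensor per single-bucket sub-batch
        torch.tensor([0.4, 0.6]),
    ]

    # What the loop above accumulates: each sub-batch mean, weighted by its size.
    weighted = sum(v.mean() * len(v) / batch_size for v in sub_batch_losses)

    # What one fixed-size batch of all 6 samples would have produced.
    flat = torch.cat(sub_batch_losses).mean()
    assert torch.isclose(weighted, flat)

(The / gradient_step factor is unchanged and omitted here.)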

View File

@@ -6,7 +6,6 @@ from PIL import Image
 from torch.utils.data import Dataset, DataLoader, Sampler
 from torchvision import transforms
 from collections import defaultdict
-from random import shuffle, choices
 
 import random
 import tqdm
@@ -167,23 +166,85 @@ class GroupedBatchSampler(Sampler):
         b = self.batch_size
 
         for g in self.groups:
-            shuffle(g)
+            random.shuffle(g)
 
         batches = []
         for g in self.groups:
             batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
         for _ in range(self.n_rand_batches):
-            rand_group = choices(self.groups, self.probs)[0]
-            batches.append(choices(rand_group, k=b))
+            rand_group = random.choices(self.groups, self.probs)[0]
+            batches.append(random.choices(rand_group, k=b))
 
-        shuffle(batches)
+        random.shuffle(batches)
 
         yield from batches
+
+def greedy_pack(tails, b):
+    '''inefficient suboptimal packing of remainders from each bucket'''
+    by_len = defaultdict(list)
+    for t in tails:
+        by_len[len(t)].append(t)
+    n = sum(len(t) for t in tails) // b
+    superbatches = []
+    for _ in range(n):
+        to_pick = b
+        superbatch = []
+        while to_pick:
+            for k, v in sorted(by_len.items(), reverse=True):
+                # try pick longest
+                if k <= to_pick:
+                    to_pick -= k
+                    superbatch.append(v.pop())
+                    if not v:
+                        del(by_len[k])
+                    break
+            else:
+                # can't find any so split a group
+                maxlen = max(by_len)
+                tail = by_len[maxlen].pop()
+                if not by_len[maxlen]:
+                    del by_len[maxlen]
+                superbatch.append(tail[:to_pick])
+                by_len[len(tail[to_pick:])].append(tail[to_pick:])
+                to_pick = 0
+        superbatches.append(superbatch)
+    return superbatches
+
+class VariableBatchSampler(Sampler):
+    def __init__(self, data_source: PersonalizedBase, batch_size: int):
+        self.n = len(data_source)
+        self.groups = data_source.groups
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        b = self.batch_size
+        dropped = set(random.sample(range(self.n), self.n % b))
+        groups = [[x for x in g if x not in dropped] for g in self.groups]
+        for g in groups:
+            random.shuffle(g)
+
+        superbatches = []
+        for g in groups:
+            superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
+
+        tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
+        random.shuffle(tails)
+        superbatches.extend(greedy_pack(tails, b))
+
+        random.shuffle(superbatches)
+        yield from [batch for superbatch in superbatches for batch in superbatch]
+
+def group_batches(batches, batch_size):
+    m, superbatch = 0, []
+    for batch in batches:
+        superbatch.append(batch)
+        m += len(batch)
+        assert m <= batch_size
+        if m == batch_size:
+            yield superbatch
+            m, superbatch = 0, []
 
 
 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
         if latent_sampling_method == "random":
             self.collate_fn = collate_wrapper_random
         else:
@@ -197,6 +258,9 @@ class BatchLoader:
         self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)
 
+    def __len__(self):
+        return len(self.cond)
+
     def pin_memory(self):
         self.latent_sample = self.latent_sample.pin_memory()
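To make the packing step concrete, here is a hypothetical walk-through of greedy_pack with invented indices, assuming it is importable from modules.textual_inversion.dataset as defined above (illustration only, not code from this commit):

    from modules.textual_inversion.dataset import greedy_pack

    tails = [[0, 1, 2], [3, 4], [5]]     # per-bucket remainders of size 3, 2 and 1
    print(greedy_pack(tails, 4))
    # [[[0, 1, 2], [5]]] -- one superbatch of 4 samples made of two single-bucket
    # sub-batches; only sum(len(t) for t in tails) // 4 == 1 superbatch is built,
    # so [3, 4] goes unused here (VariableBatchSampler avoids that by dropping
    # n % b random indices before the tails are formed).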

View File

@@ -459,7 +459,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -473,21 +473,22 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     clip_grad_sched.step(embedding.step)
 
                 with devices.autocast():
-                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                    c = shared.sd_model.cond_stage_model(batch.cond_text)
+                    for batch in superbatch:
+                        x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                        c = shared.sd_model.cond_stage_model(batch.cond_text)
 
-                    if is_training_inpainting_model:
-                        if img_c is None:
-                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+                        if is_training_inpainting_model:
+                            if img_c is None:
+                                img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
 
-                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
-                    else:
-                        cond = c
+                            cond = {"c_concat": [img_c], "c_crossattn": [c]}
+                        else:
+                            cond = c
 
-                    loss = shared.sd_model(x, cond)[0] / gradient_step
-                    del x
+                        loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
+                        del x
 
-                    _loss_step += loss.item()
-                    scaler.scale(loss).backward()
+                        _loss_step += loss.item()
+                        scaler.scale(loss).backward()
 
                 # go back until we reach gradient accumulation steps
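Taken together, the new sampler and group_batches preserve the optimizer-step semantics of the fixed-size loader: every superbatch regrouped in the training loops above holds exactly batch_size samples, split into sub-batches that each come from a single bucket (which is also why __len__ was added to BatchLoader, so group_batches can measure the real DataLoader batches). A toy sanity check with an invented stand-in for PersonalizedBase, assuming the new helpers are importable from modules.textual_inversion.dataset (illustration only, not part of the commit):

    from modules.textual_inversion.dataset import VariableBatchSampler, group_batches

    class FakeDataset:                     # minimal stand-in for PersonalizedBase
        def __init__(self, groups):
            self.groups = groups           # index buckets, e.g. one per resolution
        def __len__(self):
            return sum(len(g) for g in self.groups)

    ds = FakeDataset([[0, 1, 2, 3, 4], [5, 6, 7], [8, 9]])   # 10 samples, 3 buckets
    sampler = VariableBatchSampler(ds, batch_size=4)

    # 10 % 4 == 2 indices are dropped per epoch, leaving 8 samples in 2 superbatches,
    # each of which sums back to exactly batch_size when regrouped.
    for superbatch in group_batches(sampler, 4):
        assert sum(len(batch) for batch in superbatch) == 4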