Variable sized batches
This commit is contained in:
parent
b165e341e7
commit
90bc57da51
|
@ -582,7 +582,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
|
@ -596,18 +596,19 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
clip_grad_sched.step(hypernetwork.step)
|
||||
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if tag_drop_out != 0 or shuffle_tags:
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
else:
|
||||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
del x
|
||||
del c
|
||||
for batch in superbatch:
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if tag_drop_out != 0 or shuffle_tags:
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
else:
|
||||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
|
||||
del x
|
||||
del c
|
||||
|
||||
_loss_step += loss.item()
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# go back until we reach gradient accumulation steps
|
||||
|
|
|
@ -6,7 +6,6 @@ from PIL import Image
|
|||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
from collections import defaultdict
|
||||
from random import shuffle, choices
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
|
@ -167,23 +166,85 @@ class GroupedBatchSampler(Sampler):
|
|||
b = self.batch_size
|
||||
|
||||
for g in self.groups:
|
||||
shuffle(g)
|
||||
random.shuffle(g)
|
||||
|
||||
batches = []
|
||||
for g in self.groups:
|
||||
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||
for _ in range(self.n_rand_batches):
|
||||
rand_group = choices(self.groups, self.probs)[0]
|
||||
batches.append(choices(rand_group, k=b))
|
||||
rand_group = random.choices(self.groups, self.probs)[0]
|
||||
batches.append(random.choices(rand_group, k=b))
|
||||
|
||||
shuffle(batches)
|
||||
random.shuffle(batches)
|
||||
|
||||
yield from batches
|
||||
|
||||
def greedy_pack(tails, b):
|
||||
'''inefficient suboptimal packing of remainders from each bucket'''
|
||||
by_len = defaultdict(list)
|
||||
for t in tails:
|
||||
by_len[len(t)].append(t)
|
||||
n = sum(len(t) for t in tails) // b
|
||||
superbatches = []
|
||||
for _ in range(n):
|
||||
to_pick = b
|
||||
superbatch = []
|
||||
while to_pick:
|
||||
for k, v in sorted(by_len.items(), reverse=True):
|
||||
# try pick longest
|
||||
if k <= to_pick:
|
||||
to_pick -= k
|
||||
superbatch.append(v.pop())
|
||||
if not v:
|
||||
del(by_len[k])
|
||||
break
|
||||
else:
|
||||
# can't find any so split a group
|
||||
maxlen = max(by_len)
|
||||
tail = by_len[maxlen].pop()
|
||||
if not by_len[maxlen]:
|
||||
del by_len[maxlen]
|
||||
superbatch.append(tail[:to_pick])
|
||||
by_len[len(tail[to_pick:])].append(tail[to_pick:])
|
||||
to_pick = 0
|
||||
superbatches.append(superbatch)
|
||||
return superbatches
|
||||
|
||||
class VariableBatchSampler(Sampler):
|
||||
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||
self.n = len(data_source)
|
||||
self.groups = data_source.groups
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __iter__(self):
|
||||
b = self.batch_size
|
||||
dropped = set(random.sample(range(self.n), self.n % b))
|
||||
groups = [[x for x in g if x not in dropped] for g in self.groups]
|
||||
for g in groups:
|
||||
random.shuffle(g)
|
||||
superbatches = []
|
||||
for g in groups:
|
||||
superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
|
||||
tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
|
||||
random.shuffle(tails)
|
||||
superbatches.extend(greedy_pack(tails, b))
|
||||
random.shuffle(superbatches)
|
||||
yield from [batch for superbatch in superbatches for batch in superbatch]
|
||||
|
||||
def group_batches(batches, batch_size):
|
||||
m, superbatch = 0, []
|
||||
for batch in batches:
|
||||
superbatch.append(batch)
|
||||
m += len(batch)
|
||||
assert m <= batch_size
|
||||
if m == batch_size:
|
||||
yield superbatch
|
||||
m, superbatch = 0, []
|
||||
|
||||
|
||||
class PersonalizedDataLoader(DataLoader):
|
||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||
if latent_sampling_method == "random":
|
||||
self.collate_fn = collate_wrapper_random
|
||||
else:
|
||||
|
@ -198,6 +259,9 @@ class BatchLoader:
|
|||
#self.emb_index = [entry.emb_index for entry in data]
|
||||
#print(self.latent_sample.device)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.cond)
|
||||
|
||||
def pin_memory(self):
|
||||
self.latent_sample = self.latent_sample.pin_memory()
|
||||
return self
|
||||
|
|
|
@ -459,7 +459,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
|
@ -473,21 +473,22 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
clip_grad_sched.step(embedding.step)
|
||||
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||
for batch in superbatch:
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||
|
||||
if is_training_inpainting_model:
|
||||
if img_c is None:
|
||||
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
||||
if is_training_inpainting_model:
|
||||
if img_c is None:
|
||||
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
||||
|
||||
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
||||
else:
|
||||
cond = c
|
||||
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
||||
else:
|
||||
cond = c
|
||||
|
||||
loss = shared.sd_model(x, cond)[0] / gradient_step
|
||||
del x
|
||||
loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
|
||||
del x
|
||||
|
||||
_loss_step += loss.item()
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# go back until we reach gradient accumulation steps
|
||||
|
|
Loading…
Reference in New Issue
Block a user