Variable sized batches
parent b165e341e7 · commit 90bc57da51
@@ -582,7 +582,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
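Instead of one fixed-size batch per inner step, the loop now pulls a superbatch: a group of variable-sized batches produced by the new group_batches helper in modules.textual_inversion.dataset (added further down) whose sample counts add up to batch_size. A toy demonstration of the regrouping follows the dataset changes below.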
@@ -596,6 +596,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     clip_grad_sched.step(hypernetwork.step)
 
                 with devices.autocast():
+                    for batch in superbatch:
                         x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                         if tag_drop_out != 0 or shuffle_tags:
                             shared.sd_model.cond_stage_model.to(devices.device)
@@ -603,7 +604,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                             shared.sd_model.cond_stage_model.to(devices.cpu)
                         else:
                             c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                        loss = shared.sd_model(x, c)[0] / gradient_step
+                        loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
                         del x
                         del c
 
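The extra factor len(batch) / batch_size on the loss is what keeps the superbatch equivalent to one uniform batch: assuming the model's loss is a mean over the samples it sees, weighting each variable-sized sub-batch by its share of batch_size and accumulating gradients reproduces the gradient of a single batch of batch_size samples. A minimal sketch with a toy linear model (hypothetical names, the gradient_step divisor left out for clarity) that checks the equivalence numerically:

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(6, 4)             # batch_size = 6 samples in total
y = torch.randn(6, 1)
loss_fn = torch.nn.MSELoss()      # mean over whatever batch it is given

batch_size = 6
model.zero_grad()
# accumulate over sub-batches of size 3, 2 and 1, weighting by len / batch_size
for lo, hi in [(0, 3), (3, 5), (5, 6)]:
    loss = loss_fn(model(x[lo:hi]), y[lo:hi]) * (hi - lo) / batch_size
    loss.backward()
grad_accumulated = model.weight.grad.clone()

# reference: one uniform batch of batch_size samples
model.zero_grad()
loss_fn(model(x), y).backward()
print(torch.allclose(grad_accumulated, model.weight.grad, atol=1e-6))  # True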
@@ -6,7 +6,6 @@ from PIL import Image
 from torch.utils.data import Dataset, DataLoader, Sampler
 from torchvision import transforms
 from collections import defaultdict
-from random import shuffle, choices
 
 import random
 import tqdm
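The bare shuffle and choices imports are removed; the sampler code below now reaches these (plus random.sample) through the already-imported random module.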
@@ -167,23 +166,85 @@ class GroupedBatchSampler(Sampler):
         b = self.batch_size
 
         for g in self.groups:
-            shuffle(g)
+            random.shuffle(g)
 
         batches = []
         for g in self.groups:
             batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
         for _ in range(self.n_rand_batches):
-            rand_group = choices(self.groups, self.probs)[0]
-            batches.append(choices(rand_group, k=b))
+            rand_group = random.choices(self.groups, self.probs)[0]
+            batches.append(random.choices(rand_group, k=b))
 
-        shuffle(batches)
+        random.shuffle(batches)
 
         yield from batches
 
+
+def greedy_pack(tails, b):
+    '''inefficient suboptimal packing of remainders from each bucket'''
+    by_len = defaultdict(list)
+    for t in tails:
+        by_len[len(t)].append(t)
+    n = sum(len(t) for t in tails) // b
+    superbatches = []
+    for _ in range(n):
+        to_pick = b
+        superbatch = []
+        while to_pick:
+            for k, v in sorted(by_len.items(), reverse=True):
+                # try pick longest
+                if k <= to_pick:
+                    to_pick -= k
+                    superbatch.append(v.pop())
+                    if not v:
+                        del(by_len[k])
+                    break
+            else:
+                # can't find any so split a group
+                maxlen = max(by_len)
+                tail = by_len[maxlen].pop()
+                if not by_len[maxlen]:
+                    del by_len[maxlen]
+                superbatch.append(tail[:to_pick])
+                by_len[len(tail[to_pick:])].append(tail[to_pick:])
+                to_pick = 0
+        superbatches.append(superbatch)
+    return superbatches
+
+
+class VariableBatchSampler(Sampler):
+    def __init__(self, data_source: PersonalizedBase, batch_size: int):
+        self.n = len(data_source)
+        self.groups = data_source.groups
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        b = self.batch_size
+        dropped = set(random.sample(range(self.n), self.n % b))
+        groups = [[x for x in g if x not in dropped] for g in self.groups]
+        for g in groups:
+            random.shuffle(g)
+        superbatches = []
+        for g in groups:
+            superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
+        tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
+        random.shuffle(tails)
+        superbatches.extend(greedy_pack(tails, b))
+        random.shuffle(superbatches)
+        yield from [batch for superbatch in superbatches for batch in superbatch]
+
+
+def group_batches(batches, batch_size):
+    m, superbatch = 0, []
+    for batch in batches:
+        superbatch.append(batch)
+        m += len(batch)
+        assert m <= batch_size
+        if m == batch_size:
+            yield superbatch
+            m, superbatch = 0, []
+
+
 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
         if latent_sampling_method == "random":
             self.collate_fn = collate_wrapper_random
         else:
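Three pieces are added to the dataset module. VariableBatchSampler drops n % batch_size random indices, shuffles each bucket, emits one single-batch superbatch per full slice of batch_size items, and hands the leftover tails to greedy_pack, which combines (and if necessary splits) them into superbatches whose sizes also total batch_size. Because the sampler then yields the batches of every superbatch consecutively, the consumer-side group_batches generator can stitch them back together by accumulating until the sizes sum to batch_size. A standalone sketch of that last step, using plain lists as stand-ins for the BatchLoader objects the real loader yields:

# Restatement of the group_batches generator added above; any object with a
# len() works, so plain lists stand in for BatchLoader instances here.
def group_batches(batches, batch_size):
    m, superbatch = 0, []
    for batch in batches:
        superbatch.append(batch)
        m += len(batch)
        assert m <= batch_size
        if m == batch_size:
            yield superbatch
            m, superbatch = 0, []

# the sampler emits batches in packed order, e.g. 6, then 4 + 2, then 3 + 2 + 1
stream = [list(range(6)), list(range(4)), list(range(2)),
          list(range(3)), list(range(2)), list(range(1))]
for superbatch in group_batches(stream, 6):
    print([len(b) for b in superbatch])   # [6], [4, 2], [3, 2, 1]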
@@ -198,6 +259,9 @@ class BatchLoader:
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)
 
+    def __len__(self):
+        return len(self.cond)
+
     def pin_memory(self):
         self.latent_sample = self.latent_sample.pin_memory()
         return self
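BatchLoader gains a __len__ that reports how many conditionings it carries, which is what lets both group_batches and the len(batch) / batch_size loss weighting treat a batch's size as len(batch).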
@@ -459,7 +459,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -473,6 +473,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     clip_grad_sched.step(embedding.step)
 
                 with devices.autocast():
+                    for batch in superbatch:
                         x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                         c = shared.sd_model.cond_stage_model(batch.cond_text)
 
@@ -484,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                         else:
                             cond = c
 
-                        loss = shared.sd_model(x, cond)[0] / gradient_step
+                        loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
                         del x
 
                         _loss_step += loss.item()
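train_embedding gets the same treatment as train_hypernetwork: it iterates superbatches from group_batches, runs the forward pass once per variable-sized batch inside devices.autocast(), and scales each loss by len(batch) / batch_size so the accumulated gradient matches a uniform batch of batch_size samples.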