Merge a7c92bd5ec into ea9bd9fc74
commit 820e26db7a

@@ -625,7 +625,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -638,19 +638,19 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 if clip_grad:
                     clip_grad_sched.step(hypernetwork.step)
 
-                with devices.autocast():
-                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                    if tag_drop_out != 0 or shuffle_tags:
-                        shared.sd_model.cond_stage_model.to(devices.device)
-                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
-                        shared.sd_model.cond_stage_model.to(devices.cpu)
-                    else:
-                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0] / gradient_step
-                    del x
-                    del c
-
-                    _loss_step += loss.item()
+                def get_loss(batch):
+                    with devices.autocast():
+                        x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                        if tag_drop_out != 0 or shuffle_tags:
+                            shared.sd_model.cond_stage_model.to(devices.device)
+                            c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+                            shared.sd_model.cond_stage_model.to(devices.cpu)
+                        else:
+                            c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+                        return shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
+
+                loss = sum(get_loss(batch) for batch in superbatch)
+                _loss_step += loss.item()
                 scaler.scale(loss).backward()
 
                 # go back until we reach gradient accumulation steps
@@ -727,7 +727,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                         p.width = preview_width
                         p.height = preview_height
                     else:
-                        p.prompt = batch.cond_text[0]
+                        p.prompt = superbatch[0].cond_text[0]
                         p.steps = 20
                         p.width = training_width
                         p.height = training_height
@@ -759,7 +759,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
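
Illustration (not part of the diff): with the change above, the DataLoader's variable-size batches are regrouped by group_batches into a superbatch that always totals batch_size samples, and get_loss weights each sub-batch by len(batch) / batch_size before summing. Since the per-batch loss returned by the model is a mean over its samples, the weighted sum equals the mean over the whole superbatch, so the accumulated gradient matches the old fixed-size batching. A minimal sketch of that weighting (hypothetical helper name, gradient_step left out for brevity):

import torch

def superbatch_loss(sub_batches, batch_size):
    # weight each sub-batch mean by len(batch) / batch_size, mirroring get_loss() above
    return sum(batch.mean() * len(batch) / batch_size for batch in sub_batches)

samples = torch.rand(4)              # pretend batch_size == 4
packed = [samples[:3], samples[3:]]  # one superbatch packed as sizes 3 + 1
assert torch.isclose(superbatch_loss(packed, 4), samples.mean())
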
@@ -6,7 +6,6 @@ from PIL import Image
 from torch.utils.data import Dataset, DataLoader, Sampler
 from torchvision import transforms
 from collections import defaultdict
-from random import shuffle, choices
 
 import random
 import tqdm
@@ -147,43 +146,72 @@ class PersonalizedBase(Dataset):
         return entry
 
 
-class GroupedBatchSampler(Sampler):
+def greedy_pack(tails, b):
+    '''inefficient suboptimal packing of remainders from each bucket'''
+    by_len = defaultdict(list)
+    for t in tails:
+        by_len[len(t)].append(t)
+    n = sum(len(t) for t in tails) // b
+    superbatches = []
+    for _ in range(n):
+        to_pick = b
+        superbatch = []
+        while to_pick:
+            for k, v in sorted(by_len.items(), reverse=True):
+                # try pick longest
+                if k <= to_pick:
+                    to_pick -= k
+                    superbatch.append(v.pop())
+                    if not v:
+                        del(by_len[k])
+                    break
+            else:
+                # can't find any so split a group
+                maxlen = max(by_len)
+                tail = by_len[maxlen].pop()
+                if not by_len[maxlen]:
+                    del by_len[maxlen]
+                superbatch.append(tail[:to_pick])
+                by_len[len(tail[to_pick:])].append(tail[to_pick:])
+                to_pick = 0
+        superbatches.append(superbatch)
+    return superbatches
+
+
+class VariableBatchSampler(Sampler):
     def __init__(self, data_source: PersonalizedBase, batch_size: int):
-        super().__init__(data_source)
-
-        n = len(data_source)
+        self.n = len(data_source)
         self.groups = data_source.groups
-        self.len = n_batch = n // batch_size
-        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
-        self.base = [int(e) // batch_size for e in expected]
-        self.n_rand_batches = nrb = n_batch - sum(self.base)
-        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
         self.batch_size = batch_size
 
-    def __len__(self):
-        return self.len
-
     def __iter__(self):
         b = self.batch_size
+        dropped = set(random.sample(range(self.n), self.n % b))
+        groups = [[x for x in g if x not in dropped] for g in self.groups]
+        for g in groups:
+            random.shuffle(g)
+        superbatches = []
+        for g in groups:
+            superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
+        tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
+        random.shuffle(tails)
+        superbatches.extend(greedy_pack(tails, b))
+        random.shuffle(superbatches)
+        yield from [batch for superbatch in superbatches for batch in superbatch]
 
-        for g in self.groups:
-            shuffle(g)
-
-        batches = []
-        for g in self.groups:
-            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
-        for _ in range(self.n_rand_batches):
-            rand_group = choices(self.groups, self.probs)[0]
-            batches.append(choices(rand_group, k=b))
-
-        shuffle(batches)
-
-        yield from batches
+def group_batches(batches, batch_size):
+    m, superbatch = 0, []
+    for batch in batches:
+        superbatch.append(batch)
+        m += len(batch)
+        assert m <= batch_size
+        if m == batch_size:
+            yield superbatch
+            m, superbatch = 0, []
 
 
 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
         if latent_sampling_method == "random":
             self.collate_fn = collate_wrapper_random
         else:
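
Illustration (not part of the diff): VariableBatchSampler shuffles each bucket in data_source.groups, emits as many full batches of b indices per bucket as fit, and hands the leftover tails to greedy_pack, which repacks them into superbatches that still total b items, splitting a tail when nothing fits. Assuming the dataset module from this diff is importable inside the webui repo, a toy run with made-up index groups of sizes 3, 2, 2 and 1:

from modules.textual_inversion.dataset import greedy_pack

tails = [[0, 1, 2], [3, 4], [5, 6], [7]]   # leftover index groups from four buckets
print(greedy_pack(tails, 4))
# should print two superbatches of 4 indices each, e.g.
# [[[0, 1, 2], [7]], [[5, 6], [3, 4]]]
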
@@ -197,6 +225,9 @@ class BatchLoader:
         self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)
 
+    def __len__(self):
+        return len(self.cond)
+
     def pin_memory(self):
         self.latent_sample = self.latent_sample.pin_memory()
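
Illustration (not part of the diff): BatchLoader gains __len__ because group_batches (added in the previous hunk) sums len(batch) over consecutive DataLoader batches until they reach batch_size, then yields them as one superbatch to the training loops. With plain lists standing in for BatchLoader objects (only len() is used), and assuming the module from this diff is importable:

from modules.textual_inversion.dataset import group_batches

print(list(group_batches([[0, 1, 2], [3], [4, 5], [6, 7]], 4)))
# should print: [[[0, 1, 2], [3]], [[4, 5], [6, 7]]]
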
@@ -465,7 +465,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 break
             if shared.state.interrupted:
                 break
-            for j, batch in enumerate(dl):
+            for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
                 # works as a drop_last=True for gradient accumulation
                 if j == max_steps_per_epoch:
                     break
@@ -477,23 +477,24 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 
                 if clip_grad:
                     clip_grad_sched.step(embedding.step)
 
-                with devices.autocast():
-                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                    c = shared.sd_model.cond_stage_model(batch.cond_text)
-
-                    if is_training_inpainting_model:
-                        if img_c is None:
-                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
-
-                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
-                    else:
-                        cond = c
-
-                    loss = shared.sd_model(x, cond)[0] / gradient_step
-                    del x
-
-                    _loss_step += loss.item()
+                def get_loss(batch):
+                    with devices.autocast():
+                        x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                        c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+                        if is_training_inpainting_model:
+                            if img_c is None:
+                                img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+                            cond = {"c_concat": [img_c], "c_crossattn": [c]}
+                        else:
+                            cond = c
+
+                        return shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
+
+                loss = sum(get_loss(batch) for batch in superbatch)
+                _loss_step += loss.item()
                 scaler.scale(loss).backward()
 
                 # go back until we reach gradient accumulation steps
@@ -553,7 +554,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                         p.width = preview_width
                         p.height = preview_height
                     else:
-                        p.prompt = batch.cond_text[0]
+                        p.prompt = superbatch[0].cond_text[0]
                         p.steps = 20
                         p.width = training_width
                         p.height = training_height
@@ -610,7 +611,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>