guaneec 2023-02-06 10:11:48 +08:00 committed by GitHub
commit 820e26db7a
3 changed files with 91 additions and 59 deletions

View File

@@ -625,7 +625,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             break
         if shared.state.interrupted:
             break
-        for j, batch in enumerate(dl):
+        for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
             # works as a drop_last=True for gradient accumulation
             if j == max_steps_per_epoch:
                 break
@@ -638,19 +638,19 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             if clip_grad:
                 clip_grad_sched.step(hypernetwork.step)

-            with devices.autocast():
-                x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                if tag_drop_out != 0 or shuffle_tags:
-                    shared.sd_model.cond_stage_model.to(devices.device)
-                    c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
-                    shared.sd_model.cond_stage_model.to(devices.cpu)
-                else:
-                    c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                loss = shared.sd_model(x, c)[0] / gradient_step
-                del x
-                del c
-
-            _loss_step += loss.item()
+            def get_loss(batch):
+                with devices.autocast():
+                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                    if tag_drop_out != 0 or shuffle_tags:
+                        shared.sd_model.cond_stage_model.to(devices.device)
+                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+                        shared.sd_model.cond_stage_model.to(devices.cpu)
+                    else:
+                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+                    return shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
+
+            loss = sum(get_loss(batch) for batch in superbatch)
+            _loss_step += loss.item()
             scaler.scale(loss).backward()

             # go back until we reach gradient accumulation steps
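The `* len(batch) / batch_size` factor in `get_loss` is what keeps gradient accumulation consistent now that a superbatch can be split into sub-batches of unequal size: weighting each sub-batch's mean loss by its share of the superbatch and summing reproduces the mean over all `batch_size` samples. A minimal standalone sketch of that arithmetic, not part of the commit, with made-up numbers and assuming the model call returns the mean loss over its sub-batch:

    # toy per-sample losses for one superbatch of 6 samples
    per_sample_loss = [0.9, 1.1, 0.7, 1.3, 1.0, 0.6]
    batch_size = 6
    sub_batches = [per_sample_loss[0:4], per_sample_loss[4:6]]   # sizes 4 and 2

    mean = lambda xs: sum(xs) / len(xs)
    # what the new code computes (ignoring the unchanged 1/gradient_step factor)
    weighted_sum = sum(mean(b) * len(b) / batch_size for b in sub_batches)

    # matches the mean over the whole superbatch
    assert abs(weighted_sum - mean(per_sample_loss)) < 1e-12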
@@ -727,7 +727,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -759,7 +759,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>

View File

@@ -6,7 +6,6 @@ from PIL import Image
 from torch.utils.data import Dataset, DataLoader, Sampler
 from torchvision import transforms
 from collections import defaultdict
-from random import shuffle, choices

 import random
 import tqdm
@@ -147,43 +146,72 @@ class PersonalizedBase(Dataset):
         return entry

-class GroupedBatchSampler(Sampler):
+def greedy_pack(tails, b):
+    '''inefficient suboptimal packing of remainders from each bucket'''
+    by_len = defaultdict(list)
+    for t in tails:
+        by_len[len(t)].append(t)
+    n = sum(len(t) for t in tails) // b
+    superbatches = []
+    for _ in range(n):
+        to_pick = b
+        superbatch = []
+        while to_pick:
+            for k, v in sorted(by_len.items(), reverse=True):
+                # try pick longest
+                if k <= to_pick:
+                    to_pick -= k
+                    superbatch.append(v.pop())
+                    if not v:
+                        del(by_len[k])
+                    break
+            else:
+                # can't find any so split a group
+                maxlen = max(by_len)
+                tail = by_len[maxlen].pop()
+                if not by_len[maxlen]:
+                    del by_len[maxlen]
+                superbatch.append(tail[:to_pick])
+                by_len[len(tail[to_pick:])].append(tail[to_pick:])
+                to_pick = 0
+        superbatches.append(superbatch)
+    return superbatches
+
+
+class VariableBatchSampler(Sampler):
     def __init__(self, data_source: PersonalizedBase, batch_size: int):
-        super().__init__(data_source)
-        n = len(data_source)
+        self.n = len(data_source)
         self.groups = data_source.groups
-        self.len = n_batch = n // batch_size
-        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
-        self.base = [int(e) // batch_size for e in expected]
-        self.n_rand_batches = nrb = n_batch - sum(self.base)
-        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
         self.batch_size = batch_size

-    def __len__(self):
-        return self.len
-
     def __iter__(self):
         b = self.batch_size
+        dropped = set(random.sample(range(self.n), self.n % b))
+        groups = [[x for x in g if x not in dropped] for g in self.groups]
+        for g in groups:
+            random.shuffle(g)

-        for g in self.groups:
-            shuffle(g)
-
-        batches = []
-        for g in self.groups:
-            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
-        for _ in range(self.n_rand_batches):
-            rand_group = choices(self.groups, self.probs)[0]
-            batches.append(choices(rand_group, k=b))
-        shuffle(batches)
-        yield from batches
+        superbatches = []
+        for g in groups:
+            superbatches.extend([g[i*b:(i+1)*b]] for i in range(len(g) // b))
+
+        tails = [g[-(len(g) % b):] for g in groups if len(g) % b != 0]
+        random.shuffle(tails)
+        superbatches.extend(greedy_pack(tails, b))
+
+        random.shuffle(superbatches)
+        yield from [batch for superbatch in superbatches for batch in superbatch]
+
+
+def group_batches(batches, batch_size):
+    m, superbatch = 0, []
+    for batch in batches:
+        superbatch.append(batch)
+        m += len(batch)
+        assert m <= batch_size
+        if m == batch_size:
+            yield superbatch
+            m, superbatch = 0, []


 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=VariableBatchSampler(dataset, batch_size), pin_memory=pin_memory)
         if latent_sampling_method == "random":
             self.collate_fn = collate_wrapper_random
         else:
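`greedy_pack` handles the per-bucket remainders: after each resolution bucket is cut into full batches, the leftover "tails" (index lists shorter than `batch_size`) are packed largest-first into superbatches whose sub-batch lengths sum to exactly `b`, splitting a tail when nothing small enough is left, and anything beyond `sum(len(t)) // b` superbatches is dropped for that epoch. A toy run, not part of the commit, assuming it is executed inside the webui checkout so `modules.textual_inversion.dataset` (and the heavyweight imports it pulls in) resolves:

    from modules.textual_inversion.dataset import greedy_pack

    b = 4
    tails = [[0, 1, 2], [3, 4], [5], [6, 7, 8]]    # 9 leftover dataset indices
    superbatches = greedy_pack(tails, b)           # 9 // 4 == 2 superbatches, 1 index dropped

    for sb in superbatches:
        # every packed superbatch totals exactly b samples
        assert sum(len(batch) for batch in sb) == b
    print(superbatches)                            # e.g. [[[6, 7, 8], [5]], [[0, 1, 2], [3]]]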
@@ -197,6 +225,9 @@ class BatchLoader:
         self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)

+    def __len__(self):
+        return len(self.cond)
+
     def pin_memory(self):
         self.latent_sample = self.latent_sample.pin_memory()
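Taken together: `VariableBatchSampler` yields the flattened sub-batches (their lengths within one superbatch always sum to `batch_size`), the DataLoader collates each into a `BatchLoader` (hence the new `__len__`), and `group_batches` reassembles consecutive sub-batches into superbatches for the training loops. Since `group_batches` only needs `len()` on each element, plain lists can stand in for `BatchLoader` objects in a quick check (same import caveat as the previous sketch):

    from modules.textual_inversion.dataset import group_batches

    batch_size = 4
    # plausible DataLoader output under VariableBatchSampler: 4 = 4, 3 + 1 = 4, 2 + 2 = 4
    dl = [[0, 1, 2, 3], [4, 5, 6], [7], [8, 9], [10, 11]]

    for superbatch in group_batches(dl, batch_size):
        print([len(batch) for batch in superbatch])   # [4] then [3, 1] then [2, 2]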

View File

@@ -465,7 +465,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
             break
         if shared.state.interrupted:
             break
-        for j, batch in enumerate(dl):
+        for j, superbatch in enumerate(modules.textual_inversion.dataset.group_batches(dl, batch_size)):
             # works as a drop_last=True for gradient accumulation
             if j == max_steps_per_epoch:
                 break
@@ -477,23 +477,24 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
             if clip_grad:
                 clip_grad_sched.step(embedding.step)

-            with devices.autocast():
-                x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                c = shared.sd_model.cond_stage_model(batch.cond_text)
-
-                if is_training_inpainting_model:
-                    if img_c is None:
-                        img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
-
-                    cond = {"c_concat": [img_c], "c_crossattn": [c]}
-                else:
-                    cond = c
-
-                loss = shared.sd_model(x, cond)[0] / gradient_step
-                del x
-
-            _loss_step += loss.item()
+            def get_loss(batch):
+                with devices.autocast():
+                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                    c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+                    if is_training_inpainting_model:
+                        if img_c is None:
+                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
+                    else:
+                        cond = c
+
+                    return shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
+
+            loss = sum(get_loss(batch) for batch in superbatch)
+            _loss_step += loss.item()
             scaler.scale(loss).backward()

             # go back until we reach gradient accumulation steps
@@ -553,7 +554,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -610,7 +611,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>