Fix loss calculations

Author: dan
Date: 2023-01-21 22:58:30 +08:00
Commit: f4f070a548 (parent 90bc57da51)

2 changed files with 15 additions and 16 deletions
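
The bug: each micro-batch computed its own loss inside the for-loop, but
scaler.scale(loss).backward() ran after the loop had finished, so only the
last micro-batch's loss tensor was ever backpropagated. _loss_step still
accumulated every batch's loss.item(), which is why the reported loss looked
right while the gradients were wrong. The fix moves the forward pass into a
get_loss(batch) helper and sums the returned losses over the whole superbatch
before a single backward pass. The explicit del x / del c cleanups are dropped
because those locals now die when get_loss returns, and since batch no longer
leaks out of the loop, the preview prompt and the status HTML read
superbatch[0].cond_text[0] instead.

A minimal, self-contained sketch of the bug in plain PyTorch (the model and
batches are stand-ins, not the webui code):

    import torch

    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    superbatch = [torch.randn(2, 4) for _ in range(3)]  # three stand-in micro-batches

    # Old pattern: loss is rebound on every iteration, so the single
    # backward() after the loop only sees the last micro-batch.
    for batch in superbatch:
        loss = model(batch).mean()
    loss.backward()
    grad_old = model.weight.grad.clone()

    # Fixed pattern: sum the per-batch losses, then backpropagate once.
    model.zero_grad()
    loss = sum(model(batch).mean() for batch in superbatch)
    loss.backward()
    grad_new = model.weight.grad.clone()

    print(torch.allclose(grad_old, grad_new))  # False: the old code dropped batches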

modules/hypernetworks/hypernetwork.py

@@ -595,8 +595,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             if clip_grad:
                 clip_grad_sched.step(hypernetwork.step)
 
-            with devices.autocast():
-                for batch in superbatch:
+            def get_loss(batch):
+                with devices.autocast():
                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                     if tag_drop_out != 0 or shuffle_tags:
                         shared.sd_model.cond_stage_model.to(devices.device)
@@ -604,11 +604,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                     else:
                         c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
-                    del x
-                    del c
+                    return shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
 
-                _loss_step += loss.item()
+            loss = sum(get_loss(batch) for batch in superbatch)
+            _loss_step += loss.item()
 
             scaler.scale(loss).backward()
             # go back until we reach gradient accumulation steps
@@ -683,7 +682,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -715,7 +714,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
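
Summing before a single backward() is gradient-equivalent to the more common
accumulation idiom of calling backward() once per micro-batch, because
autograd is linear in the loss; a quick check in plain PyTorch (illustrative
shapes only):

    import torch

    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    superbatch = [torch.randn(2, 4) for _ in range(3)]

    # One backward over the summed loss (the pattern adopted here).
    loss = sum(model(b).mean() for b in superbatch)
    loss.backward()
    grad_summed = model.weight.grad.clone()

    # Per-batch backward calls accumulating into .grad give the same gradients.
    model.zero_grad()
    for b in superbatch:
        model(b).mean().backward()

    print(torch.allclose(grad_summed, model.weight.grad))  # True

The tradeoff is peak memory: the summed form keeps every micro-batch's
autograd graph alive until the single backward call, whereas per-batch
backward passes would free each graph as they go. The second file applies the
same restructuring to train_embedding.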

modules/textual_inversion/textual_inversion.py

@@ -472,8 +472,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
             if clip_grad:
                 clip_grad_sched.step(embedding.step)
 
-            with devices.autocast():
-                for batch in superbatch:
+            def get_loss(batch):
+                with devices.autocast():
                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                     c = shared.sd_model.cond_stage_model(batch.cond_text)
 
@@ -485,10 +485,10 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     else:
                         cond = c
 
-                    loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
-                    del x
+                    return shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
 
-                _loss_step += loss.item()
+            loss = sum(get_loss(batch) for batch in superbatch)
+            _loss_step += loss.item()
 
             scaler.scale(loss).backward()
             # go back until we reach gradient accumulation steps
@@ -548,7 +548,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -605,7 +605,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
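
Note that the per-loss weighting / gradient_step * len(batch) / batch_size is
unchanged by this commit. Assuming shared.sd_model(x, c)[0] is the mean loss
over one batch and a superbatch spans gradient_step batches of up to
batch_size samples, the weights make the summed loss a mean over the whole
optimizer step; a worked check with assumed sizes:

    # Hypothetical sizes, purely illustrative; not taken from the diff.
    batch_size = 4       # nominal samples per micro-batch
    gradient_step = 2    # micro-batches accumulated per optimizer step
    batch_lens = [4, 4]  # len(batch) for each batch in one superbatch

    weights = [n / batch_size / gradient_step for n in batch_lens]
    print(sum(weights))  # 1.0: the summed loss is a true mean over 8 samples

A short final batch (say batch_lens = [4, 2]) simply contributes in proportion
to its size.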