Fix loss calculations
parent 90bc57da51
commit f4f070a548
@@ -595,8 +595,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
         if clip_grad:
             clip_grad_sched.step(hypernetwork.step)
 
-        with devices.autocast():
-            for batch in superbatch:
+        def get_loss(batch):
+            with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                 if tag_drop_out != 0 or shuffle_tags:
                     shared.sd_model.cond_stage_model.to(devices.device)
@@ -604,11 +604,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     shared.sd_model.cond_stage_model.to(devices.cpu)
                 else:
                     c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
-                del x
-                del c
+                return shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
 
-                _loss_step += loss.item()
+        loss = sum(get_loss(batch) for batch in superbatch)
+        _loss_step += loss.item()
         scaler.scale(loss).backward()
 
         # go back until we reach gradient accumulation steps
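The two hunks above restructure gradient accumulation: the removed lines computed a per-micro-batch loss inside the loop while scaler.scale(loss).backward() sat outside it, so only the final micro-batch's loss appears to have reached the backward call. The new code has get_loss() return the pre-scaled loss for one micro-batch and sums those losses before a single backward pass. The sketch below illustrates that pattern in isolation; it is a minimal stand-alone example, not code from this repository, and the names (model, micro_batches, targets, grad_accum_steps) are illustrative stand-ins.

import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(4, 1)                               # stand-in for the trained network
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=False)     # pass-through on CPU; enabled on CUDA in practice

grad_accum_steps = 4                                  # plays the role of gradient_step
batch_size = 2                                        # nominal batch size
micro_batches = [torch.randn(batch_size, 4) for _ in range(grad_accum_steps)]
targets = [torch.randn(batch_size, 1) for _ in range(grad_accum_steps)]

def get_loss(x, y):
    # per-micro-batch loss, pre-divided so the sum over the window behaves like a mean
    return nn.functional.mse_loss(model(x), y) / grad_accum_steps

# Summing first makes every micro-batch contribute to the single backward pass.
loss = sum(get_loss(x, y) for x, y in zip(micro_batches, targets))
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad()
print(f"accumulated loss: {loss.item():.6f}")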
@@ -683,7 +682,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -715,7 +714,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>

@@ -472,8 +472,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
         if clip_grad:
             clip_grad_sched.step(embedding.step)
 
-        with devices.autocast():
-            for batch in superbatch:
+        def get_loss(batch):
+            with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                 c = shared.sd_model.cond_stage_model(batch.cond_text)
 
@@ -485,10 +485,10 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 else:
                     cond = c
 
-                loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
-                del x
+                return shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
 
-                _loss_step += loss.item()
+        loss = sum(get_loss(batch) for batch in superbatch)
+        _loss_step += loss.item()
         scaler.scale(loss).backward()
 
         # go back until we reach gradient accumulation steps
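In both training functions the per-micro-batch loss is scaled by len(batch) / (gradient_step * batch_size) before being summed, so the accumulated value behaves like a mean over the whole accumulation window, with short final batches weighted in proportion to their size. A quick stand-alone check of that arithmetic, using made-up numbers rather than anything from this commit:

gradient_step = 4                              # micro-batches accumulated per optimizer step
batch_size = 2                                 # nominal batch size
per_batch_losses = [0.50, 0.40, 0.30, 0.20]    # hypothetical raw losses, one per micro-batch
batch_lens = [2, 2, 2, 2]                      # len(batch) for each micro-batch

scaled = [loss * n / (gradient_step * batch_size)
          for loss, n in zip(per_batch_losses, batch_lens)]
accumulated = sum(scaled)

# With full batches the accumulated loss equals the plain mean of the raw losses.
assert abs(accumulated - sum(per_batch_losses) / gradient_step) < 1e-9
print(accumulated)   # 0.35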
@@ -548,7 +548,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -605,7 +605,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>