Fix loss calculations
parent 90bc57da51
commit f4f070a548
@@ -595,8 +595,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             if clip_grad:
                 clip_grad_sched.step(hypernetwork.step)
 
+            def get_loss(batch):
                 with devices.autocast():
-                for batch in superbatch:
                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                     if tag_drop_out != 0 or shuffle_tags:
                         shared.sd_model.cond_stage_model.to(devices.device)
@@ -604,10 +604,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                     else:
                         c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
-                    del x
-                    del c
+                    return shared.sd_model(x, c)[0] / gradient_step * len(batch) / batch_size
 
+            loss = sum(get_loss(batch) for batch in superbatch)
             _loss_step += loss.item()
             scaler.scale(loss).backward()
 
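Note: before this change, loss was assigned inside the removed "for batch in superbatch:" loop while scaler.scale(loss).backward() ran after it, so only the last micro-batch's loss reached the backward pass. Factoring the per-batch computation into get_loss and summing makes every micro-batch contribute. A minimal standalone sketch of the fixed pattern (the linear model and random micro-batches are invented for illustration; only the accumulation structure mirrors the diff):

import torch

model = torch.nn.Linear(4, 1)
superbatch = [torch.randn(2, 4) for _ in range(3)]  # three illustrative micro-batches
batch_size = 2
gradient_step = len(superbatch)

def get_loss(batch):
    # per-micro-batch loss, scaled the same way as in the diff
    return model(batch).pow(2).mean() / gradient_step * len(batch) / batch_size

loss = sum(get_loss(batch) for batch in superbatch)  # every micro-batch contributes
loss.backward()  # one backward pass over the accumulated loss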
@@ -683,7 +682,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -715,7 +714,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
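With the inner per-batch loop removed, the name batch is no longer in scope where the preview image and the status HTML are built, which is why both now read the caption from superbatch[0]. A tiny sketch of the assumed data shape (FakeBatch and the captions are invented; only the attribute access mirrors the diff):

from dataclasses import dataclass

@dataclass
class FakeBatch:
    cond_text: list

superbatch = [FakeBatch(["a photo of a cat"]), FakeBatch(["a photo of a dog"])]

# after this commit, previews and the status HTML use the first micro-batch's first caption
print(superbatch[0].cond_text[0])  # -> a photo of a cat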
@@ -472,8 +472,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
             if clip_grad:
                 clip_grad_sched.step(embedding.step)
 
+            def get_loss(batch):
                 with devices.autocast():
-                for batch in superbatch:
                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                     c = shared.sd_model.cond_stage_model(batch.cond_text)
 
@@ -485,9 +485,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     else:
                         cond = c
 
-                    loss = shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
-                    del x
+                    return shared.sd_model(x, cond)[0] / gradient_step * len(batch) / batch_size
 
+            loss = sum(get_loss(batch) for batch in superbatch)
             _loss_step += loss.item()
             scaler.scale(loss).backward()
 
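The same refactor is applied to train_embedding; the explicit del x line is dropped, presumably because the temporaries now go out of scope when get_loss returns. For context, a sketch of how the accumulated loss is assumed to flow into the AMP scaler and the optimizer step (everything except the scaler/get_loss/superbatch pattern is a placeholder, and the GradScaler is disabled so the sketch also runs on CPU):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled: passthrough, CPU-friendly

def get_loss(batch):
    return model(batch).pow(2).mean()

superbatch = [torch.randn(2, 4) for _ in range(4)]

loss = sum(get_loss(batch) for batch in superbatch)
_loss_step = loss.item()           # the logged value now covers the whole superbatch
scaler.scale(loss).backward()      # single backward pass for the accumulated loss
scaler.step(optimizer)             # unscales (when enabled) and steps the optimizer
scaler.update()
optimizer.zero_grad(set_to_none=True)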
@@ -548,7 +548,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     p.width = preview_width
                     p.height = preview_height
                 else:
-                    p.prompt = batch.cond_text[0]
+                    p.prompt = superbatch[0].cond_text[0]
                     p.steps = 20
                     p.width = training_width
                     p.height = training_height
@@ -605,7 +605,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 <p>
 Loss: {loss_step:.7f}<br/>
 Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last prompt: {html.escape(superbatch[0].cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>