fix the issue with training on SD2.0
This commit is contained in:
parent
e672cfb074
commit
311354c0bb
|
@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
|||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
|
|
@ -282,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
return embedding, filename
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
|
@ -310,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
|
Loading…
Reference in New Issue
Block a user