Fix dropout and implement train/eval mode for Hypernetwork
This commit is contained in:
parent
89d8ecff09
commit
d2c97fc3fe
|
@ -154,16 +154,28 @@ class Hypernetwork:
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||||
)
|
)
|
||||||
|
self.eval_mode()
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
res = []
|
res = []
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
res += layer.parameters()
|
||||||
|
return res
|
||||||
|
|
||||||
|
def train_mode(self):
|
||||||
for k, layers in self.layers.items():
|
for k, layers in self.layers.items():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.train()
|
layer.train()
|
||||||
res += layer.trainables()
|
for param in layer.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
return res
|
def eval_mode(self):
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
layer.eval()
|
||||||
|
for param in layer.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
@ -426,8 +438,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
weights = hypernetwork.weights()
|
weights = hypernetwork.weights()
|
||||||
for weight in weights:
|
hypernetwork.train_mode()
|
||||||
weight.requires_grad = True
|
|
||||||
|
|
||||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||||
if hypernetwork.optimizer_name in optimizer_dict:
|
if hypernetwork.optimizer_name in optimizer_dict:
|
||||||
|
@ -538,7 +549,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
hypernetwork.eval_mode()
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
@ -571,7 +582,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
hypernetwork.train_mode()
|
||||||
if image is not None:
|
if image is not None:
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
|
@ -593,6 +604,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
finally:
|
finally:
|
||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
hypernetwork.eval_mode()
|
||||||
#report_statistics(loss_dict)
|
#report_statistics(loss_dict)
|
||||||
|
|
||||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user