From 1764ac3c8bc482bd575987850e96630d9115e51a Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:49:26 +0900 Subject: [PATCH] use hash to check valid optim --- modules/hypernetworks/hypernetwork.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 63c25de8..4230b8cf 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -177,11 +177,12 @@ class Hypernetwork: state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name if self.optimizer_name is not None: optimizer_saved_dict['optimizer_name'] = self.optimizer_name - if self.optimizer_state_dict: - optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict - torch.save(optimizer_saved_dict, filename + '.optim') torch.save(state_dict, filename) + if self.optimizer_state_dict: + optimizer_saved_dict['hash'] = sd_models.model_hash(filename) + optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict + torch.save(optimizer_saved_dict, filename + '.optim') def load(self, filename): self.filename = filename @@ -204,7 +205,10 @@ class Hypernetwork: optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print(f"Optimizer name is {self.optimizer_name}") - self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): + self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + else: + self.optimizer_state_dict = None if self.optimizer_state_dict: print("Loaded existing optimizer from checkpoint") else: @@ -229,7 +233,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": - res[name] = filename + res[name + f"({sd_models.model_hash(filename)})"] = filename return res @@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: hypernetwork_dir = None + hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] if create_image_every > 0: images_dir = os.path.join(log_directory, "images") os.makedirs(images_dir, exist_ok=True)