From 668d7e9b9aba1770beae48a8664e0351fcd59f31 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 5 Feb 2023 11:20:47 +0300 Subject: [PATCH] make it possible to load SD1 checkpoints without CLIP --- modules/sd_disable_initialization.py | 17 ++++++++++------- modules/sd_models.py | 6 +++++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index e90aa9fe..c4a09d15 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,8 +20,9 @@ class DisableInitialization: ``` """ - def __init__(self): + def __init__(self, disable_clip=True): self.replaced = [] + self.disable_clip = disable_clip def replace(self, obj, field, func): original = getattr(obj, field, None) @@ -75,12 +76,14 @@ class DisableInitialization: self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) - self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) - self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) - self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) - self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) - self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) - self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) + + if self.disable_clip: + self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) + self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) + self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) + self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) + self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) + self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): for obj, field, original in self.replaced: diff --git a/modules/sd_models.py b/modules/sd_models.py index af1731e5..d847d358 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -354,6 +354,9 @@ def repair_config(sd_config): sd_config.model.params.unet_config.params.use_fp16 = True +sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' +sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' + def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -374,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict timer.record("find config") @@ -386,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ sd_model = None try: - with sd_disable_initialization.DisableInitialization(): + with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): sd_model = instantiate_from_config(sd_config.model) except Exception as e: pass