diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ccbaa9ad..7b2030d4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -201,7 +201,7 @@ class StableDiffusionModelHijack: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = torch.load(path) + data = torch.load(path, map_location="cpu") # textual inversion embeddings if 'string_to_param' in data: @@ -217,7 +217,7 @@ class StableDiffusionModelHijack: if len(emb.shape) == 1: emb = emb.unsqueeze(0) - self.word_embeddings[name] = emb.detach() + self.word_embeddings[name] = emb.detach().to(device) self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]