diff --git a/modules/processing.py b/modules/processing.py index bb94033b..d8c6b8d5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -360,7 +360,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps) + c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a3b12421..f7420daf 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) def get_learned_conditioning(model, prompts, steps): + """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), + and the sampling step at which this condition is to be replaced by the next one. + + Input: + (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20) + + Output: + [ + [ + ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0')) + ], + [ + ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')), + ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0')) + ] + ] + """ res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) @@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps): cache[prompt] = cond_schedule res.append(cond_schedule) - return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res) + return res -def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): - param = c.schedules[0][0].cond - res = torch.zeros(c.shape, device=param.device, dtype=param.dtype) - for i, cond_schedule in enumerate(c.schedules): +re_AND = re.compile(r"\bAND\b") +re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$") + + +def get_multicond_prompt_list(prompts): + res_indexes = [] + + prompt_flat_list = [] + prompt_indexes = {} + + for prompt in prompts: + subprompts = re_AND.split(prompt) + + indexes = [] + for subprompt in subprompts: + text, weight = re_weight.search(subprompt).groups() + + weight = float(weight) if weight is not None else 1.0 + + index = prompt_indexes.get(text, None) + if index is None: + index = len(prompt_flat_list) + prompt_flat_list.append(text) + prompt_indexes[text] = index + + indexes.append((index, weight)) + + res_indexes.append(indexes) + + return res_indexes, prompt_flat_list, prompt_indexes + + +class ComposableScheduledPromptConditioning: + def __init__(self, schedules, weight=1.0): + self.schedules: list[ScheduledPromptConditioning] = schedules + self.weight: float = weight + + +class MulticondLearnedConditioning: + def __init__(self, shape, batch): + self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS + self.batch: list[list[ComposableScheduledPromptConditioning]] = batch + + +def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: + """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. + For each prompt, the list is obtained by splitting the prompt using the AND separator. + + https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/ + """ + + res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) + + learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) + + res = [] + for indexes in res_indexes: + res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]) + + return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) + + +def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): + param = c[0][0].cond + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, (end_at, cond) in enumerate(cond_schedule): if current_step <= end_at: @@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): return res +def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): + param = c.batch[0][0].schedules[0].cond + + tensors = [] + conds_list = [] + + for batch_no, composable_prompts in enumerate(c.batch): + conds_for_batch = [] + + for cond_index, composable_prompt in enumerate(composable_prompts): + target_index = 0 + for current, (end_at, cond) in enumerate(composable_prompt.schedules): + if current_step <= end_at: + target_index = current + break + + conds_for_batch.append((len(tensors), composable_prompt.weight)) + tensors.append(composable_prompt.schedules[target_index].cond) + + conds_list.append(conds_for_batch) + + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + + re_attention = re.compile(r""" \\\(| \\\)| diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index dbf570d2..d27c547b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -109,9 +109,12 @@ class VanillaStableDiffusionSampler: return 0 def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' + cond = tensor + if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec @@ -183,19 +186,31 @@ class CFGDenoiser(torch.nn.Module): self.step = 0 def forward(self, x, sigma, uncond, cond, cond_scale): - cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + cond_in = torch.cat([tensor, uncond]) + if shared.batch_cond_uncond: - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - denoised = uncond + (cond - uncond) * cond_scale + x_out = self.inner_model(x_in, sigma_in, cond=cond_in) else: - uncond = self.inner_model(x, sigma, cond=uncond) - cond = self.inner_model(x, sigma, cond=cond) - denoised = uncond + (cond - uncond) * cond_scale + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + + denoised_uncond = x_out[-batch_size:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised diff --git a/modules/ui.py b/modules/ui.py index 523ab25b..9620350f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -34,7 +34,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste -from modules.prompt_parser import get_learned_conditioning_prompt_schedules +from modules import prompt_parser from modules.images import apply_filename_pattern, get_next_sequence_number import modules.textual_inversion.ui @@ -394,7 +394,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: - prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + except Exception: # a parsing error can happen here during typing, and we don't want to bother the user with # messages related to it in console