diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index ee4c5d02..a7a6aa31 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,6 +1,6 @@ import re from collections import namedtuple - +from typing import List import lark # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" @@ -175,15 +175,14 @@ def get_multicond_prompt_list(prompts): class ComposableScheduledPromptConditioning: def __init__(self, schedules, weight=1.0): - self.schedules = schedules # : list[ScheduledPromptConditioning] + self.schedules: List[ScheduledPromptConditioning] = schedules self.weight: float = weight class MulticondLearnedConditioning: def __init__(self, shape, batch): self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS - self.batch = batch # : list[list[ComposableScheduledPromptConditioning]] - + self.batch: List[List[ComposableScheduledPromptConditioning]] = batch def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. @@ -203,7 +202,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) -def reconstruct_cond_batch(c, current_step): # c: list[list[ScheduledPromptConditioning]] +def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): param = c[0][0].cond res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c):