diff --git a/modules/processing.py b/modules/processing.py
index f773a30e..03c76070 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -372,8 +372,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.autocast():
-                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+            #with devices.autocast():
+
+            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
 
             if state.interrupted:
 
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f00256f2..c5fbbbfe 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -10,6 +10,7 @@ import lark
 # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
 # [75, 'fantasy landscape with a lake and an oak in background masterful']
 # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+from modules import devices
 
 schedule_parser = lark.Lark(r"""
 !start: (prompt | /[][():]/+)*
@@ -130,7 +131,7 @@ def get_learned_conditioning(model, prompts, steps):
             continue
 
         texts = [x[1] for x in prompt_schedule]
-        conds = model.get_learned_conditioning(texts)
+        conds = model.get_learned_conditioning(texts).to(devices.dtype)
 
         cond_schedule = []
         for i, (end_at_step, text) in enumerate(prompt_schedule):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 5f992064..5110e447 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -5,7 +5,9 @@ from collections import namedtuple
 import torch
 from omegaconf import OmegaConf
 
-
+import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
+import ldm.modules.diffusionmodules.util
 from ldm.util import instantiate_from_config
 
 from modules import shared, modelloader, devices
@@ -27,6 +29,23 @@ except Exception:
     pass
 
 
+def timestep_embedding(*args, **kwargs):
+    return ldm_modules_diffusionmodules_util_timestep_embedding(*args, **kwargs).to(devices.dtype)
+
+
+ldm_modules_diffusionmodules_util_timestep_embedding = ldm.modules.diffusionmodules.openaimodel.timestep_embedding
+ldm.modules.diffusionmodules.openaimodel.timestep_embedding = timestep_embedding
+
+
+class GroupNorm32(torch.nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x).type(x.dtype)
+
+
+ldm.modules.diffusionmodules.util.GroupNorm32 = GroupNorm32
+
+
+
 def setup_model():
     if not os.path.exists(model_path):
         os.makedirs(model_path)
@@ -133,6 +152,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
         model.half()
 
     devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    model.model.diffusion_model.dtype = devices.dtype
+    torch.set_default_tensor_type(torch.FloatTensor if shared.cmd_opts.no_half else torch.HalfTensor)
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpint = checkpoint_file
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 497df943..4e12d354 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
 import k_diffusion.sampling
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 
@@ -339,9 +339,13 @@ class KDiffusionSampler:
 
         if p.sampler_noise_scheduler_override:
             sigmas = p.sampler_noise_scheduler_override(steps)
+        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
 
+        sigmas = sigmas.to(devices.dtype)
+
         noise = noise * sigmas[steps - t_enc - 1]
         xi = x + noise
 
@@ -363,6 +367,8 @@ class KDiffusionSampler:
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
 
+        sigmas = sigmas.to(devices.dtype)
+
         x = x * sigmas[0]
 
         extra_params_kwargs = self.initialize(p)
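
Not part of the patch: a minimal sketch of the wrap-and-cast monkeypatch pattern that the sd_models.py hunk applies to ldm.modules.diffusionmodules.openaimodel.timestep_embedding, so downstream callers always receive tensors in the working dtype. The helper name, stand-in module, and dtype below are illustrative assumptions, not names from the repository.

import torch

def cast_output(fn, dtype):
    # Return a wrapper that calls the original function and casts its result,
    # mirroring timestep_embedding(*args, **kwargs).to(devices.dtype) in the hunk above.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs).to(dtype)
    return wrapper

# Hypothetical usage: stash a reference to the original, then install the wrapper
# in its place, as the patch does for openaimodel.timestep_embedding.
# original = some_module.timestep_embedding
# some_module.timestep_embedding = cast_output(original, torch.float16)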