From 53e7616b5133a0bffc799cae8b1a66395f975f3a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 31 Aug 2022 15:09:40 +0300 Subject: [PATCH] DDIM support returned for img2img --- webui.py | 79 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/webui.py b/webui.py index b8088795..80952b79 100644 --- a/webui.py +++ b/webui.py @@ -94,7 +94,7 @@ samplers = [ SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)), SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)), ] -samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS'] +samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) @@ -835,9 +835,37 @@ class StableDiffusionProcessing: raise NotImplementedError() +def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs): + if sampler_wrapper.mask is not None: + img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts) + x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec + + return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs) + + class VanillaStableDiffusionSampler: def __init__(self, constructor): self.sampler = constructor(sd_model) + self.orig_p_sample_ddim = self.sampler.p_sample_ddim + self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs) + self.mask = None + self.nmask = None + self.init_latent = None + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning): + t_enc = int(min(p.denoising_strength, 0.999) * p.steps) + + self.sampler.make_schedule(ddim_num_steps=p.steps, ddim_eta=0.0, verbose=False) + x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(device), noise=noise) + + self.mask = p.mask + self.nmask = p.nmask + self.init_latent = p.init_latent + + samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning) + + return samples + def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning): samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x) @@ -864,6 +892,27 @@ class KDiffusionSampler: self.func = getattr(k_diffusion.sampling, self.funcname) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning): + t_enc = int(min(p.denoising_strength, 0.999) * p.steps) + sigmas = self.model_wrap.get_sigmas(p.steps) + noise = noise * sigmas[p.steps - t_enc - 1] + + xi = x + noise + + if p.mask is not None: + if p.inpainting_fill == 2: + xi = xi * p.mask + noise * p.nmask + elif p.inpainting_fill == 3: + xi = xi * p.mask + + sigma_sched = sigmas[p.steps - t_enc - 1:] + + def mask_cb(v): + v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask + + return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None) + + def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning): sigmas = self.model_wrap.get_sigmas(p.steps) x = x * sigmas[0] @@ -1246,39 +1295,20 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L') latmask = self.original_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) - latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255 + latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255 latmask = latmask[0] latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype) self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype) - - def sample(self, x, conditioning, unconditional_conditioning): - t_enc = int(min(self.denoising_strength, 0.999) * self.steps) - - sigmas = self.sampler.model_wrap.get_sigmas(self.steps) - noise = x * sigmas[self.steps - t_enc - 1] - xi = self.init_latent + noise + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) if self.mask is not None: - if self.inpainting_fill == 2: - xi = xi * self.mask + noise * self.nmask - elif self.inpainting_fill == 3: - xi = xi * self.mask + samples = samples * self.nmask + self.init_latent * self.mask - sigma_sched = sigmas[self.steps - t_enc - 1:] - - def mask_cb(v): - v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask - - samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None) - - if self.mask is not None: - samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask - - return samples_ddim + return samples def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): @@ -1544,6 +1574,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in if have_realesrgan and RealESRGAN_upscaling != 1.0: image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) + os.makedirs(outpath, exist_ok=True) base_count = len(os.listdir(outpath)) save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)