DDIM support returned for img2img

This commit is contained in:
AUTOMATIC 2022-08-31 15:09:40 +03:00
parent 9427e4e290
commit 53e7616b51

View File

@ -94,7 +94,7 @@ samplers = [
SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)), SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)), SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
] ]
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS'] samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@ -835,9 +835,37 @@ class StableDiffusionProcessing:
raise NotImplementedError() raise NotImplementedError()
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
if sampler_wrapper.mask is not None:
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor): def __init__(self, constructor):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
self.mask = None
self.nmask = None
self.init_latent = None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
self.sampler.make_schedule(ddim_num_steps=p.steps, ddim_eta=0.0, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(device), noise=noise)
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
return samples
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning): def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x) samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
@ -864,6 +892,27 @@ class KDiffusionSampler:
self.func = getattr(k_diffusion.sampling, self.funcname) self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
sigmas = self.model_wrap.get_sigmas(p.steps)
noise = noise * sigmas[p.steps - t_enc - 1]
xi = x + noise
if p.mask is not None:
if p.inpainting_fill == 2:
xi = xi * p.mask + noise * p.nmask
elif p.inpainting_fill == 3:
xi = xi * p.mask
sigma_sched = sigmas[p.steps - t_enc - 1:]
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning): def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps) sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0] x = x * sigmas[0]
@ -1246,39 +1295,20 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L') self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
latmask = self.original_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = self.original_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255 latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
latmask = latmask[0] latmask = latmask[0]
latmask = np.tile(latmask[None], (4, 1, 1)) latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype) self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype) self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
def sample(self, x, conditioning, unconditional_conditioning): def sample(self, x, conditioning, unconditional_conditioning):
t_enc = int(min(self.denoising_strength, 0.999) * self.steps) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
noise = x * sigmas[self.steps - t_enc - 1]
xi = self.init_latent + noise
if self.mask is not None: if self.mask is not None:
if self.inpainting_fill == 2: samples = samples * self.nmask + self.init_latent * self.mask
xi = xi * self.mask + noise * self.nmask
elif self.inpainting_fill == 3:
xi = xi * self.mask
sigma_sched = sigmas[self.steps - t_enc - 1:] return samples
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None)
if self.mask is not None:
samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask
return samples_ddim
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
@ -1544,6 +1574,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
if have_realesrgan and RealESRGAN_upscaling != 1.0: if have_realesrgan and RealESRGAN_upscaling != 1.0:
image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
os.makedirs(outpath, exist_ok=True)
base_count = len(os.listdir(outpath)) base_count = len(os.listdir(outpath))
save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True) save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)