diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 6c57fdec..f076fc55 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -41,90 +41,6 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } -class CFGDenoiserEdit(torch.nn.Module): - """ - Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) - that can take a noisy picture and produce a noise-free picture using two guidances (prompts) - instead of one. Originally, the second prompt is just an empty string, but we use non-empty - negative prompt. - """ - - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - out_cond, out_img_cond, out_uncond = x_out.chunk(3) - denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index]) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - class CFGDenoiser(torch.nn.Module): """ @@ -141,6 +57,7 @@ class CFGDenoiser(torch.nn.Module): self.nmask = None self.init_latent = None self.step = 0 + self.image_cfg_scale = None def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -152,19 +69,36 @@ class CFGDenoiser(torch.nn.Module): return denoised + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. + is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -173,7 +107,10 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) + if not is_edit_model: + cond_in = torch.cat([tensor, uncond]) + else: + cond_in = torch.cat([tensor, uncond, uncond]) if shared.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) @@ -189,7 +126,13 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) + + if not is_edit_model: + c_crossattn = [tensor[a:b]] + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]}) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) @@ -200,7 +143,10 @@ class CFGDenoiser(torch.nn.Module): elif opts.live_preview_content == "Negative prompt": sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + if not is_edit_model: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + else: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised @@ -280,12 +226,10 @@ class KDiffusionSampler: return p.steps def initialize(self, p): - if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1: - self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap) - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else opts.eta_ancestral k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) @@ -355,9 +299,6 @@ class KDiffusionSampler: 'cond_scale': p.cfg_scale, } - if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None: - extra_args['image_cfg_scale'] = p.image_cfg_scale - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples