From 87e8b9a2ab3f033e7fdadbb2fe258857915980ac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 16 Sep 2022 09:47:03 +0300 Subject: [PATCH] prevent replacing torch_randn globally (instead replacing k_diffusion.sampling.torch) and add a setting to disable this all --- modules/processing.py | 2 +- modules/sd_samplers.py | 25 ++++++++++++++++++++----- modules/shared.py | 3 ++- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index aab72903..5abdfd7c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -122,7 +122,7 @@ def slerp(val, low, high): def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): xs = [] - if p is not None and p.sampler is not None and len(seeds) > 1: + if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f77fe43f..d478c5bc 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -175,7 +175,19 @@ def extended_trange(count, *args, **kwargs): shared.total_tqdm.update() -original_randn_like = torch.randn_like +class TorchHijack: + def __init__(self, kdiff_sampler): + self.kdiff_sampler = kdiff_sampler + + def __getattr__(self, item): + if item == 'randn_like': + return self.kdiff_sampler.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + class KDiffusionSampler: def __init__(self, funcname, sd_model): @@ -186,8 +198,6 @@ class KDiffusionSampler: self.sampler_noises = None self.sampler_noise_index = 0 - k_diffusion.sampling.torch.randn_like = self.randn_like - def callback_state(self, d): store_latent(d["denoised"]) @@ -200,8 +210,7 @@ class KDiffusionSampler: if noise is not None and x.shape == noise.shape: res = noise else: - print('generating') - res = original_randn_like(x) + res = torch.randn_like(x) self.sampler_noise_index += 1 return res @@ -223,6 +232,9 @@ class KDiffusionSampler: if hasattr(k_diffusion.sampling, 'trange'): k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs) + if self.sampler_noises is not None: + k_diffusion.sampling.torch = TorchHijack(self) + return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state) def sample(self, p, x, conditioning, unconditional_conditioning): @@ -232,6 +244,9 @@ class KDiffusionSampler: if hasattr(k_diffusion.sampling, 'trange'): k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs) + if self.sampler_noises is not None: + k_diffusion.sampling.torch = TorchHijack(self) + samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state) return samples_ddim diff --git a/modules/shared.py b/modules/shared.py index bc39ad1c..ac870ec4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -124,7 +124,8 @@ class Options: "add_model_hash_to_info": OptionInfo(False, "Add model hash to generation information"), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "font": OptionInfo("", "Font for image grids that have text"), - "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"), + "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "ESRGAN_tile": OptionInfo(192, "Tile size for upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),