remove dependence on TQDM for sampler progress/interrupt functionality
This commit is contained in:
parent
ec1924ee57
commit
cbf15edbf9
|
@ -402,12 +402,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||
|
||||
if state.interrupted or state.skipped:
|
||||
|
||||
# if we are interrupted, sample returns just noise
|
||||
# use the image collected previously in sampler loop
|
||||
samples_ddim = shared.state.current_latent
|
||||
|
||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
|
|
@ -98,25 +98,8 @@ def store_latent(decoded):
|
|||
shared.state.current_image = sample_to_image(decoded)
|
||||
|
||||
|
||||
|
||||
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
||||
state.sampling_steps = len(sequence)
|
||||
state.sampling_step = 0
|
||||
|
||||
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||
|
||||
for x in seq:
|
||||
if state.interrupted or state.skipped:
|
||||
break
|
||||
|
||||
yield x
|
||||
|
||||
state.sampling_step += 1
|
||||
shared.total_tqdm.update()
|
||||
|
||||
|
||||
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
||||
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
|
@ -128,14 +111,32 @@ class VanillaStableDiffusionSampler:
|
|||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 0.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
|
@ -159,11 +160,16 @@ class VanillaStableDiffusionSampler:
|
|||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
store_latent(self.init_latent * self.mask + self.nmask * res[1])
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||
else:
|
||||
store_latent(res[1])
|
||||
self.last_latent = res[1]
|
||||
|
||||
store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
|
@ -192,7 +198,7 @@ class VanillaStableDiffusionSampler:
|
|||
self.init_latent = x
|
||||
self.step = 0
|
||||
|
||||
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
|
||||
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
|
@ -206,9 +212,9 @@ class VanillaStableDiffusionSampler:
|
|||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
except Exception:
|
||||
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
@ -223,6 +229,9 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.step = 0
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
|
@ -268,25 +277,6 @@ class CFGDenoiser(torch.nn.Module):
|
|||
return denoised
|
||||
|
||||
|
||||
def extended_trange(sampler, count, *args, **kwargs):
|
||||
state.sampling_steps = count
|
||||
state.sampling_step = 0
|
||||
|
||||
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||
|
||||
for x in seq:
|
||||
if state.interrupted or state.skipped:
|
||||
break
|
||||
|
||||
if sampler.stop_at is not None and x > sampler.stop_at:
|
||||
break
|
||||
|
||||
yield x
|
||||
|
||||
state.sampling_step += 1
|
||||
shared.total_tqdm.update()
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, kdiff_sampler):
|
||||
self.kdiff_sampler = kdiff_sampler
|
||||
|
@ -314,9 +304,28 @@ class KDiffusionSampler:
|
|||
self.eta = None
|
||||
self.default_eta = 1.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
def callback_state(self, d):
|
||||
store_latent(d["denoised"])
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
@ -339,9 +348,6 @@ class KDiffusionSampler:
|
|||
self.sampler_noise_index = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
if hasattr(k_diffusion.sampling, 'trange'):
|
||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
||||
|
||||
if self.sampler_noises is not None:
|
||||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
|
||||
|
@ -383,8 +389,9 @@ class KDiffusionSampler:
|
|||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
|
||||
return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||
steps = steps or p.steps
|
||||
|
@ -406,6 +413,8 @@ class KDiffusionSampler:
|
|||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user