Make VAE step sequential to prevent VRAM spikes
This commit is contained in:
parent
0b5dcb3d7c
commit
67efee33a6
|
@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||||
|
|
||||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
del samples_ddim
|
del samples_ddim
|
||||||
|
|
Loading…
Reference in New Issue
Block a user