bugfix: save image for hires fix BEFORE upscaling latent space
This commit is contained in:
parent
321e13ca17
commit
f674c488d9
|
@ -665,17 +665,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
||||||
|
|
||||||
if opts.use_scale_latent_for_hires_fix:
|
if opts.use_scale_latent_for_hires_fix:
|
||||||
|
for i in range(samples.shape[0]):
|
||||||
|
save_intermediate(samples, i)
|
||||||
|
|
||||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||||
|
|
||||||
# Avoid making the inpainting conditioning unless necessary as
|
# Avoid making the inpainting conditioning unless necessary as
|
||||||
# this does need some extra compute to decode / encode the image again.
|
# this does need some extra compute to decode / encode the image again.
|
||||||
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
||||||
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
||||||
else:
|
else:
|
||||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||||
|
|
||||||
for i in range(samples.shape[0]):
|
|
||||||
save_intermediate(samples, i)
|
|
||||||
else:
|
else:
|
||||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user