Fixed copying mistake

This commit is contained in:
random_thoughtss 2022-10-19 13:56:26 -07:00
parent 8e7097d06a
commit 0719c10bf1

View File

@ -19,63 +19,35 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# ================================================================================================= # =================================================================================================
@torch.no_grad() @torch.no_grad()
def sample( def sample(self,
self, S,
S, batch_size,
batch_size, shape,
shape, conditioning=None,
conditioning=None, callback=None,
callback=None, normals_sequence=None,
normals_sequence=None, img_callback=None,
img_callback=None, quantize_x0=False,
quantize_x0=False, eta=0.,
eta=0., mask=None,
mask=None, x0=None,
x0=None, temperature=1.,
temperature=1., noise_dropout=0.,
noise_dropout=0., score_corrector=None,
score_corrector=None, corrector_kwargs=None,
corrector_kwargs=None, verbose=True,
verbose=True, x_T=None,
x_T=None, log_every_t=100,
log_every_t=100, unconditional_guidance_scale=1.,
unconditional_guidance_scale=1., unconditional_conditioning=None,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs
**kwargs ):
):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]] ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): while isinstance(ctmp, list):
ctmp = elf.inpainting_fill == 2: ctmp = ctmp[0]
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opctmp[0]
cbs = ctmp.shape[0] cbs = ctmp.shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
@ -106,7 +78,6 @@ def sample(
) )
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,