Merge pull request #7234 from brkirch/fix-full-previews
Fix full previews and--no-half-vae to work correctly with --upcast-sampling
This commit is contained in:
commit
645f4e7ef8
|
@ -172,7 +172,7 @@ class StableDiffusionProcessing:
|
||||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||||
|
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
|
||||||
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
|
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
|
||||||
conditioning = torch.nn.functional.interpolate(
|
conditioning = torch.nn.functional.interpolate(
|
||||||
self.sd_model.depth_model(midas_in),
|
self.sd_model.depth_model(midas_in),
|
||||||
|
@ -217,7 +217,7 @@ class StableDiffusionProcessing:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encode the new masked image using first stage of network.
|
# Encode the new masked image using first stage of network.
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
|
||||||
|
|
||||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||||
|
@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||||
|
|
||||||
def decode_first_stage(model, x):
|
def decode_first_stage(model, x):
|
||||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||||
x = model.decode_first_stage(x)
|
x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
image = torch.from_numpy(batch_images)
|
image = torch.from_numpy(batch_images)
|
||||||
image = 2. * image - 1.
|
image = 2. * image - 1.
|
||||||
image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
|
image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
|
||||||
|
|
||||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ class CondFunc:
|
||||||
self = super(CondFunc, cls).__new__(cls)
|
self = super(CondFunc, cls).__new__(cls)
|
||||||
if isinstance(orig_func, str):
|
if isinstance(orig_func, str):
|
||||||
func_path = orig_func.split('.')
|
func_path = orig_func.split('.')
|
||||||
for i in range(len(func_path)-2, -1, -1):
|
for i in range(len(func_path)-1, -1, -1):
|
||||||
try:
|
try:
|
||||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in New Issue
Block a user