Merge pull request #7309 from brkirch/fix-embeddings
Fix embeddings, upscalers, and refactor `--upcast-sampling`
This commit is contained in:
commit
fecb990deb
|
@ -87,6 +87,14 @@ dtype_unet = torch.float16
|
||||||
unet_needs_upcast = False
|
unet_needs_upcast = False
|
||||||
|
|
||||||
|
|
||||||
|
def cond_cast_unet(input):
|
||||||
|
return input.to(dtype_unet) if unet_needs_upcast else input
|
||||||
|
|
||||||
|
|
||||||
|
def cond_cast_float(input):
|
||||||
|
return input.float() if unet_needs_upcast else input
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if device.type == 'mps':
|
if device.type == 'mps':
|
||||||
|
@ -199,6 +207,3 @@ if has_mps():
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||||
orig_narrow = torch.narrow
|
|
||||||
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
|
||||||
|
|
||||||
|
|
|
@ -173,8 +173,7 @@ class StableDiffusionProcessing:
|
||||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||||
|
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||||
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
|
|
||||||
conditioning = torch.nn.functional.interpolate(
|
conditioning = torch.nn.functional.interpolate(
|
||||||
self.sd_model.depth_model(midas_in),
|
self.sd_model.depth_model(midas_in),
|
||||||
size=conditioning_image.shape[2:],
|
size=conditioning_image.shape[2:],
|
||||||
|
@ -218,7 +217,7 @@ class StableDiffusionProcessing:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encode the new masked image using first stage of network.
|
# Encode the new masked image using first stage of network.
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||||
|
|
||||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||||
|
@ -229,16 +228,18 @@ class StableDiffusionProcessing:
|
||||||
return image_conditioning
|
return image_conditioning
|
||||||
|
|
||||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||||
|
source_image = devices.cond_cast_float(source_image)
|
||||||
|
|
||||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||||
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
|
return self.depth2img_image_conditioning(source_image)
|
||||||
|
|
||||||
if self.sd_model.cond_stage_key == "edit":
|
if self.sd_model.cond_stage_key == "edit":
|
||||||
return self.edit_image_conditioning(source_image)
|
return self.edit_image_conditioning(source_image)
|
||||||
|
|
||||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||||
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
|
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||||
|
|
||||||
# Dummy zero conditioning if we're not using inpainting or depth model.
|
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||||
|
@ -418,7 +419,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||||
|
|
||||||
def decode_first_stage(model, x):
|
def decode_first_stage(model, x):
|
||||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||||
x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
|
x = model.decode_first_stage(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -1007,7 +1008,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
image = torch.from_numpy(batch_images)
|
image = torch.from_numpy(batch_images)
|
||||||
image = 2. * image - 1.
|
image = 2. * image - 1.
|
||||||
image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
|
image = image.to(shared.device)
|
||||||
|
|
||||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||||
scale=info.scale,
|
scale=info.scale,
|
||||||
model_path=info.local_data_path,
|
model_path=info.local_data_path,
|
||||||
model=info.model(),
|
model=info.model(),
|
||||||
half=not cmd_opts.no_half,
|
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||||
tile=opts.ESRGAN_tile,
|
tile=opts.ESRGAN_tile,
|
||||||
tile_pad=opts.ESRGAN_tile_overlap,
|
tile_pad=opts.ESRGAN_tile_overlap,
|
||||||
)
|
)
|
||||||
|
|
|
@ -173,7 +173,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
vecs = []
|
vecs = []
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
for offset, embedding in fixes:
|
for offset, embedding in fixes:
|
||||||
emb = embedding.vec
|
emb = devices.cond_cast_unet(embedding.vec)
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||||
|
|
||||||
|
|
|
@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||||
|
|
||||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||||
|
|
||||||
|
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||||
|
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||||
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||||
|
|
|
@ -10,7 +10,7 @@ then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
export install_dir="$HOME"
|
export install_dir="$HOME"
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
|
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
|
||||||
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
||||||
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
||||||
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user