Merge pull request #7309 from brkirch/fix-embeddings

Fix embeddings, upscalers, and refactor `--upcast-sampling`
This commit is contained in:
AUTOMATIC1111 2023-01-28 18:44:36 +03:00 committed by GitHub
commit fecb990deb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 26 additions and 14 deletions

View File

@ -87,6 +87,14 @@ dtype_unet = torch.float16
unet_needs_upcast = False unet_needs_upcast = False
def cond_cast_unet(input):
return input.to(dtype_unet) if unet_needs_upcast else input
def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
def randn(seed, shape): def randn(seed, shape):
torch.manual_seed(seed) torch.manual_seed(seed)
if device.type == 'mps': if device.type == 'mps':
@ -199,6 +207,3 @@ if has_mps():
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )

View File

@ -173,8 +173,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate( conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in), self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:], size=conditioning_image.shape[2:],
@ -218,7 +217,7 @@ class StableDiffusionProcessing:
) )
# Encode the new masked image using first stage of network. # Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat` # Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@ -229,16 +228,18 @@ class StableDiffusionProcessing:
return image_conditioning return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid. # identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion): if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) return self.depth2img_image_conditioning(source_image)
if self.sd_model.cond_stage_key == "edit": if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image) return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}: if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model. # Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@ -418,7 +419,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
def decode_first_stage(model, x): def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae): with devices.autocast(disable=x.dtype == devices.dtype_vae):
x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) x = model.decode_first_stage(x)
return x return x
@ -1007,7 +1008,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images) image = torch.from_numpy(batch_images)
image = 2. * image - 1. image = 2. * image - 1.
image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) image = image.to(shared.device)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

View File

@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
scale=info.scale, scale=info.scale,
model_path=info.local_data_path, model_path=info.local_data_path,
model=info.model(), model=info.model(),
half=not cmd_opts.no_half, half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile, tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap, tile_pad=opts.ESRGAN_tile_overlap,
) )

View File

@ -173,7 +173,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = [] vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes: for offset, embedding in fixes:
emb = embedding.vec emb = devices.cond_cast_unet(embedding.vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

View File

@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"): if version.parse(torch.__version__) <= version.parse("1.13.1"):
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

View File

@ -10,7 +10,7 @@ then
fi fi
export install_dir="$HOME" export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"