diff --git a/.gitignore b/.gitignore index 0b1d17ca..e360cc9e 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,7 @@ notification.mp3 /test/stdout.txt /test/stderr.txt /cache.json +special_start_medvram.bat +special_start_medvram.bat +special_start_medvram ohne xformers.bat +webui-user.bat diff --git a/.vs/ProjectSettings.json b/.vs/ProjectSettings.json new file mode 100644 index 00000000..f8b48885 --- /dev/null +++ b/.vs/ProjectSettings.json @@ -0,0 +1,3 @@ +{ + "CurrentProjectSetting": null +} \ No newline at end of file diff --git a/.vs/VSWorkspaceState.json b/.vs/VSWorkspaceState.json new file mode 100644 index 00000000..27f8c4b4 --- /dev/null +++ b/.vs/VSWorkspaceState.json @@ -0,0 +1,14 @@ +{ + "ExpandedNodes": [ + "", + "\\.github", + "\\configs", + "\\extensions-builtin", + "\\html", + "\\models", + "\\scripts", + "\\test", + "\\textual_inversion_templates" + ], + "PreviewInSolutionExplorer": false +} \ No newline at end of file diff --git a/.vs/slnx.sqlite b/.vs/slnx.sqlite new file mode 100644 index 00000000..7e160e1a Binary files /dev/null and b/.vs/slnx.sqlite differ diff --git a/.vs/stable-diffusion-webui2/FileContentIndex/643815af-76c4-4837-b969-d697d2a8e66c.vsidx b/.vs/stable-diffusion-webui2/FileContentIndex/643815af-76c4-4837-b969-d697d2a8e66c.vsidx new file mode 100644 index 00000000..61f7b390 Binary files /dev/null and b/.vs/stable-diffusion-webui2/FileContentIndex/643815af-76c4-4837-b969-d697d2a8e66c.vsidx differ diff --git a/.vs/stable-diffusion-webui2/FileContentIndex/7b3bb1f1-db2b-4b44-a74a-8d6d38f8c8dd.vsidx b/.vs/stable-diffusion-webui2/FileContentIndex/7b3bb1f1-db2b-4b44-a74a-8d6d38f8c8dd.vsidx new file mode 100644 index 00000000..36054bbe Binary files /dev/null and b/.vs/stable-diffusion-webui2/FileContentIndex/7b3bb1f1-db2b-4b44-a74a-8d6d38f8c8dd.vsidx differ diff --git a/.vs/stable-diffusion-webui2/FileContentIndex/a1c9ce19-b76b-493c-abae-90bbb908dc34.vsidx b/.vs/stable-diffusion-webui2/FileContentIndex/a1c9ce19-b76b-493c-abae-90bbb908dc34.vsidx new file mode 100644 index 00000000..7ca26afe Binary files /dev/null and b/.vs/stable-diffusion-webui2/FileContentIndex/a1c9ce19-b76b-493c-abae-90bbb908dc34.vsidx differ diff --git a/.vs/stable-diffusion-webui2/FileContentIndex/dc1d22eb-9771-47ac-8f24-f0b617c2ed70.vsidx b/.vs/stable-diffusion-webui2/FileContentIndex/dc1d22eb-9771-47ac-8f24-f0b617c2ed70.vsidx new file mode 100644 index 00000000..998d0eae Binary files /dev/null and b/.vs/stable-diffusion-webui2/FileContentIndex/dc1d22eb-9771-47ac-8f24-f0b617c2ed70.vsidx differ diff --git a/.vs/stable-diffusion-webui2/FileContentIndex/read.lock b/.vs/stable-diffusion-webui2/FileContentIndex/read.lock new file mode 100644 index 00000000..e69de29b diff --git a/.vs/stable-diffusion-webui2/v17/.wsuo b/.vs/stable-diffusion-webui2/v17/.wsuo new file mode 100644 index 00000000..a8435ed6 Binary files /dev/null and b/.vs/stable-diffusion-webui2/v17/.wsuo differ diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py new file mode 100644 index 00000000..79058bca --- /dev/null +++ b/optimizedSD/ddpm.py @@ -0,0 +1,1057 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import time, math +from tqdm.auto import trange, tqdm +import torch +from einops 
import rearrange +from tqdm import tqdm +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from functools import partial +from pytorch_lightning.utilities.distributed import rank_zero_only +from ldm.util import exists, default, instantiate_from_config +from ldm.modules.diffusionmodules.util import make_beta_schedule +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff + +def disabled_train(self): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + timesteps=1000, + beta_schedule="linear", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + + +class FirstStage(DDPM): + """main class""" + def __init__(self, + first_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__() + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + self.instantiate_first_stage(first_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + +class CondStage(DDPM): + """main class""" + def __init__(self, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__() + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = 
getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + + def forward(self, x, t, cc): + out = self.diffusion_model(x, t, context=cc) + return out + +class DiffusionWrapperOut(pl.LightningModule): + def __init__(self, diff_model_config): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + + def forward(self, h,emb,tp,hs, cc): + return self.diffusion_model(h,emb,tp,hs, context=cc) + + +class UNet(DDPM): + """main class""" + def __init__(self, + unetConfigEncode, + unetConfigDecode, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + unet_bs = 1, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.num_downs = 0 + self.cdevice = "cuda" + self.unetConfigEncode = unetConfigEncode + self.unetConfigDecode = unetConfigDecode + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + self.model1 = DiffusionWrapper(self.unetConfigEncode) + self.model2 = DiffusionWrapperOut(self.unetConfigDecode) + self.model1.eval() + self.model2.eval() + self.turbo = False + self.unet_bs = unet_bs + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.cdevice) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if(not self.turbo): + self.model1.to(self.cdevice) + + step = self.unet_bs + h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step]) + bs = cond.shape[0] + + # assert bs%2 == 0 + lenhs = len(hs) + + for i in range(step,bs,step): + h_temp,emb_temp,hs_temp = self.model1(x_noisy[i:i+step], t[i:i+step], cond[i:i+step]) + h = torch.cat((h,h_temp)) + emb = torch.cat((emb,emb_temp)) + for j in range(lenhs): + hs[j] = torch.cat((hs[j], hs_temp[j])) + + + if(not self.turbo): + self.model1.to("cpu") + self.model2.to(self.cdevice) + + hs_temp = [hs[j][:step] for j in range(lenhs)] + x_recon = self.model2(h[:step],emb[:step],x_noisy.dtype,hs_temp,cond[:step]) + + for i in range(step,bs,step): + + hs_temp = [hs[j][i:i+step] for j in range(lenhs)] + x_recon1 = self.model2(h[i:i+step],emb[i:i+step],x_noisy.dtype,hs_temp,cond[i:i+step]) + x_recon = torch.cat((x_recon, x_recon1)) + + if(not self.turbo): + self.model2.to("cpu") + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def register_buffer1(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.cdevice): + attr = attr.to(torch.device(self.cdevice)) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + + + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.num_timesteps,verbose=verbose) + + + assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + + to_torch = lambda x: x.to(self.cdevice) + self.register_buffer1('betas', to_torch(self.betas)) + self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod)) + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer1('ddim_sigmas', ddim_sigmas) + self.register_buffer1('ddim_alphas', ddim_alphas) + self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) + + + @torch.no_grad() + def sample(self, + S, + conditioning, + x0=None, + shape = None, + seed=1234, + callback=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + sampler = "plms", + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + ): + + + if(self.turbo): + self.model1.to(self.cdevice) + self.model2.to(self.cdevice) + + if x0 is None: + batch_size, b1, b2, b3 = shape + img_shape = (1, b1, b2, b3) + tens = [] + print("seeds used = ", [seed+s for s in range(batch_size)]) + for _ in range(batch_size): + torch.manual_seed(seed) + tens.append(torch.randn(img_shape, device=self.cdevice)) + seed+=1 + noise = torch.cat(tens) + del tens + + x_latent = noise if x0 is None else x0 + # sampling + if sampler in ('ddim', 'dpm2', 'heun', 'dpm2_a', 'lms') and not hasattr(self, 'ddim_timesteps'): + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + + if sampler == "plms": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + print(f'Data shape for PLMS sampling is {shape}') + samples = self.plms_sampling(conditioning, batch_size, x_latent, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + + elif sampler == "ddim": + samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask = mask,init_latent=x_T,use_original_steps=False, + callback=callback, img_callback=img_callback) + + elif sampler == "euler": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + callback=callback, img_callback=img_callback) + elif sampler == "euler_a": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + callback=callback, img_callback=img_callback) + + elif sampler == "dpm2": + samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + callback=callback, img_callback=img_callback) + elif sampler == "heun": + samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + callback=callback, img_callback=img_callback) + + elif sampler == "dpm2_a": + samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + callback=callback, img_callback=img_callback) + + + elif sampler == "lms": + samples = 
self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + callback=callback, img_callback=img_callback) + + if(self.turbo): + self.model1.to("cpu") + self.model2.to("cpu") + + return samples + + @torch.no_grad() + def plms_sampling(self, cond,b, img, + ddim_use_original_steps=False, + callback=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + + device = self.betas.device + timesteps = self.ddim_timesteps + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + return img + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.ddim_alphas + alphas_prev = self.ddim_alphas_prev + sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ 
= self.first_stage_model.quantize(pred_x0)
+            # direction pointing to x_t
+            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        return x_prev, pred_x0, e_t
+
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
+        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+
+        if noise is None:
+            b0, b1, b2, b3 = x0.shape
+            img_shape = (1, b1, b2, b3)
+            tens = []
+            print("seeds used = ", [seed+s for s in range(b0)])
+            for _ in range(b0):
+                torch.manual_seed(seed)
+                tens.append(torch.randn(img_shape, device=x0.device))
+                seed+=1
+            noise = torch.cat(tens)
+            del tens
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def add_noise(self, x0, t):
+
+        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+        noise = torch.randn(x0.shape, device=x0.device)
+
+        # print(extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape),
+        #       extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape))
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
+
+
+    @torch.no_grad()
+    def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+                      mask = None,init_latent=None,use_original_steps=False,
+                      callback=None, img_callback=None):
+
+        timesteps = self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        x0 = init_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+
+            if mask is not None:
+                # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
+                x0_noisy = x0
+                x_dec = x0_noisy* mask + (1.
- mask) * x_dec + + x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + + if callback: callback(i) + if img_callback: img_callback(x_dec, i) + + if mask is not None: + return x0 * mask + (1. - mask) * x_dec + + return x_dec + + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.ddim_alphas + alphas_prev = self.ddim_alphas_prev + sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev + + + @torch.no_grad() + def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., img_callback=None): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps") + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = (sigmas[i] * (gamma + 1)).half() + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + + s_i = sigma_hat * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d = to_d(x, sigma_hat, denoised) + if callback: callback(i) + if img_callback: img_callback(x, i) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + @torch.no_grad() + def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, img_callback=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps") + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + + s_i = sigmas[i] * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback: callback(i) + if img_callback: img_callback(x, i) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + torch.randn_like(x) * sigma_up + return x + + + + @torch.no_grad() + def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., img_callback=None): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + + cvd = CompVisDenoiser(alphas_cumprod=ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps") + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = (sigmas[i] * (gamma + 1)).half() + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + + s_i = sigma_hat * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + d = to_d(x, sigma_hat, denoised) + if callback: callback(i) + if img_callback: img_callback(x, i) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + s_i = sigmas[i + 1] * s_in + x_in = torch.cat([x_2] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + + @torch.no_grad() + def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., img_callback=None): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps") + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + + s_i = sigma_hat * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if callback: callback(i) + if img_callback: img_callback(x, i) + + d = to_d(x, sigma_hat, denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + + s_i = sigma_mid * s_in + x_in = torch.cat([x_2] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + + @torch.no_grad() + def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, img_callback=None): + """Ancestral sampling with DPM-Solver inspired second-order steps.""" + extra_args = {} if extra_args is None else extra_args + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps") + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + + s_i = sigmas[i] * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback: callback(i) + if img_callback: img_callback(x, i) + d = to_d(x, sigmas[i], denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + + s_i = sigma_mid * s_in + x_in = torch.cat([x_2] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + torch.randn_like(x) * sigma_up + return x + + + @torch.no_grad() + def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, 
disable=None, order=4, img_callback=None): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps") + + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + + s_i = sigmas[i] * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if callback: callback(i) + if img_callback: img_callback(x, i) + + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py new file mode 100644 index 00000000..7a32ffe9 --- /dev/null +++ b/optimizedSD/openaimodelSplit.py @@ -0,0 +1,807 @@ +from abc import abstractmethod +import math +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from .splitAttention import SpatialTransformer + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. 
+ :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModelEncode(nn.Module): + + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + 
) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + def forward(self, x, timesteps=None, context=None, y=None): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + + return h, emb, hs + + +class UNetModelDecode(nn.Module): + + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = 
ch // num_heads if use_spatial_transformer else num_head_channels + + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward(self, h,emb,tp,hs, context=None, y=None): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(tp) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) \ No newline at end of file diff --git a/optimizedSD/optimUtils.py b/optimizedSD/optimUtils.py new file mode 100644 index 00000000..18b99679 --- /dev/null +++ b/optimizedSD/optimUtils.py @@ -0,0 +1,73 @@ +import os +import pandas as pd + + +def split_weighted_subprompts(text): + """ + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + remaining = len(text) + prompts = [] + weights = [] + while remaining > 0: + if ":" in text: + idx = text.index(":") # first occurrence from start + # grab up to index as sub-prompt + prompt = text[:idx] + remaining -= idx + # remove from main text + text = text[idx+1:] + # find value for weight + if " " in text: + idx = text.index(" ") # first occurence + else: # no space, read to end + idx = len(text) + if idx != 0: + try: + weight = float(text[:idx]) + except: # couldn't treat as float + print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") + weight = 1.0 + else: # no value found + weight = 1.0 + # remove from main text + remaining -= idx + text = text[idx+1:] + # append the sub-prompt and its weight + prompts.append(prompt) + weights.append(weight) + else: # no : found + if len(text) > 0: # there is still text though + # take remainder as weight 1 + prompts.append(text) + weights.append(1.0) + remaining = 0 + return prompts, weights + +def logger(params, log_csv): + os.makedirs('logs', exist_ok=True) + cols = [arg for arg, _ in params.items()] + if not os.path.exists(log_csv): + df = pd.DataFrame(columns=cols) + df.to_csv(log_csv, index=False) + + df = pd.read_csv(log_csv) + for arg in cols: + if arg not in df.columns: + df[arg] = "" + df.to_csv(log_csv, index = False) + + li = {} + cols = [col for col in df.columns] + data = {arg:value for arg, value in params.items()} + for col in cols: + if col in data: + li[col] = data[col] + else: + li[col] = '' + + df = pd.DataFrame(li,index = [0]) + df.to_csv(log_csv,index=False, mode='a', header=False) \ No newline at end of file diff --git a/optimizedSD/samplers.py b/optimizedSD/samplers.py new file mode 100644 index 00000000..6a68e8e1 --- /dev/null +++ b/optimizedSD/samplers.py @@ -0,0 +1,252 @@ +from scipy import integrate +import torch +from tqdm.auto import trange, tqdm +import torch.nn as nn + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + +def get_ancestral_step(sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +class DiscreteSchedule(nn.Module): + """A mapping between continuous noise levels (sigmas) and a list of discrete noise + levels.""" + + def __init__(self, 
sigmas, quantize): + super().__init__() + self.register_buffer('sigmas', sigmas) + self.quantize = quantize + + def get_sigmas(self, n=None): + if n is None: + return append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize + dists = torch.abs(sigma - self.sigmas[:, None]) + if quantize: + return torch.argmin(dists, dim=0).view(sigma.shape) + low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] + low, high = self.sigmas[low_idx], self.sigmas[high_idx] + w = (low - sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + # print(low_idx, high_idx, w ) + return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx] + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out + +class CompVisDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for CompVis diffusion models.""" + + def __init__(self, alphas_cumprod, quantize=False, device='cpu'): + super().__init__(alphas_cumprod, quantize=quantize) + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +@torch.no_grad() +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + + +@torch.no_grad() +def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + torch.randn_like(x) * sigma_up + return x + + +@torch.no_grad() +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +@torch.no_grad() +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + +@torch.no_grad() +def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Ancestral sampling with DPM-Solver inspired second-order steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + torch.randn_like(x) * sigma_up + return x + + +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. 
+ for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + + +@torch.no_grad() +def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x diff --git a/optimizedSD/splitAttention.py b/optimizedSD/splitAttention.py new file mode 100644 index 00000000..dbfd459e --- /dev/null +++ b/optimizedSD/splitAttention.py @@ -0,0 +1,280 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + self.att_step = att_step + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + + limit = k.shape[0] + att_step = self.att_step + q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0)) + k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0)) + v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0)) + + q_chunks.reverse() + k_chunks.reverse() + v_chunks.reverse() + sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + del k, q, v + for i in range (0, limit, att_step): + + q_buffer = q_chunks.pop() + k_buffer = k_chunks.pop() + v_buffer = v_chunks.pop() + sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale + + del k_buffer, q_buffer + # 
attention, what we cannot get enough of, by chunks + + sim_buffer = sim_buffer.softmax(dim=-1) + + sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) + del v_buffer + sim[i:i+att_step,:,:] = sim_buffer + + del sim_buffer + sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) + return self.to_out(sim) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/optimizedSD/v1-inference.yaml b/optimizedSD/v1-inference.yaml new file mode 100644 index 00000000..2e535fcb --- /dev/null +++ b/optimizedSD/v1-inference.yaml @@ -0,0 +1,114 @@ +modelUNet: + base_learning_rate: 1.0e-04 + target: optimizedSD.ddpm.UNet + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + unetConfigEncode: + target: optimizedSD.openaimodelSplit.UNetModelEncode + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + unetConfigDecode: + target: 
optimizedSD.openaimodelSplit.UNetModelDecode + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + +modelFirstStage: + target: optimizedSD.ddpm.FirstStage + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + +modelCondStage: + target: optimizedSD.ddpm.CondStage + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + device: cpu diff --git a/patch_p2p.patch b/patch_p2p.patch new file mode 100644 index 00000000..fd6ef4cb --- /dev/null +++ b/patch_p2p.patch @@ -0,0 +1,49 @@ +From 269833067de1e7d0b6a6bd65724743d6b88a133f Mon Sep 17 00:00:00 2001 +From: Kyle +Date: Thu, 2 Feb 2023 09:37:01 -0500 +Subject: [PATCH] instruct-pix2pix support + +--- + modules/processing.py | 2 +- + modules/sd_samplers_kdiffusion.py | 8 ++++---- + 2 files changed, 5 insertions(+), 5 deletions(-) + +diff --git a/modules/processing.py b/modules/processing.py +index e544c2e16..f299e04da 100644 +--- a/modules/processing.py ++++ b/modules/processing.py +@@ -186,7 +186,7 @@ def depth2img_image_conditioning(self, source_image): + return conditioning + + def edit_image_conditioning(self, source_image): +- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) ++ conditioning_image = self.sd_model.encode_first_stage(source_image).mode() + + return conditioning_image + +diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py +index aa7f106b3..31ee22d3f 100644 +--- a/modules/sd_samplers_kdiffusion.py ++++ b/modules/sd_samplers_kdiffusion.py +@@ -77,9 +77,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + +- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) +- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) +- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) ++ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) ++ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in 
enumerate(repeats)] + [sigma] + [sigma]) ++ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + cfg_denoiser_callback(denoiser_params) +@@ -88,7 +88,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): + sigma_in = denoiser_params.sigma + + if tensor.shape[1] == uncond.shape[1]: +- cond_in = torch.cat([tensor, uncond]) ++ cond_in = torch.cat([tensor, uncond, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) \ No newline at end of file diff --git a/scripts/external_masking.py b/scripts/external_masking.py new file mode 100644 index 00000000..3193d67f --- /dev/null +++ b/scripts/external_masking.py @@ -0,0 +1,271 @@ +import math +import os +import sys +import traceback + + +import cv2 +from PIL import Image +import numpy as np + +lastx,lasty=None,None +zoomOrigin = 0,0 +zoomFactor = 1 + +midDragStart = None + +def display_mask_ui(image,mask,max_size,initPolys): + global lastx,lasty,zoomOrigin,zoomFactor + + lastx,lasty=None,None + zoomOrigin = 0,0 + zoomFactor = 1 + + polys = initPolys + + def on_mouse(event, x, y, buttons, param): + global lastx,lasty,zoomFactor,midDragStart,zoomOrigin + + lastx,lasty = (x+zoomOrigin[0])/zoomFactor,(y+zoomOrigin[1])/zoomFactor + + if event == cv2.EVENT_LBUTTONDOWN: + polys[-1].append((lastx,lasty)) + elif event == cv2.EVENT_RBUTTONDOWN: + polys.append([]) + elif event == cv2.EVENT_MBUTTONDOWN: + midDragStart = zoomOrigin[0]+x,zoomOrigin[1]+y + elif event == cv2.EVENT_MBUTTONUP: + if midDragStart is not None: + zoomOrigin = max(0,midDragStart[0]-x),max(0,midDragStart[1]-y) + midDragStart = None + elif event == cv2.EVENT_MOUSEMOVE: + if midDragStart is not None: + zoomOrigin = max(0,midDragStart[0]-x),max(0,midDragStart[1]-y) + elif event == cv2.EVENT_MOUSEWHEEL: + origZoom = zoomFactor + if buttons > 0: + zoomFactor *= 1.1 + else: + zoomFactor *= 0.9 + zoomFactor = max(1,zoomFactor) + + zoomOrigin = max(0,int(zoomOrigin[0]+ (max_size*0.25*(zoomFactor-origZoom)))) , max(0,int(zoomOrigin[1] + (max_size*0.25*(zoomFactor-origZoom)))) + + + + opencvImage = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + + if mask is None: + opencvMask = cv2.cvtColor( np.array(opencvImage) , cv2.COLOR_BGR2GRAY) + else: + opencvMask = np.array(mask) + + + maxdim = max(opencvImage.shape[1],opencvImage.shape[0]) + + factor = max_size/maxdim + + + cv2.namedWindow('MaskingWindow', cv2.WINDOW_AUTOSIZE) + cv2.setWindowProperty('MaskingWindow', cv2.WND_PROP_TOPMOST, 1) + cv2.setMouseCallback('MaskingWindow', on_mouse) + + font = cv2.FONT_HERSHEY_SIMPLEX + + srcImage = opencvImage.copy() + combinedImage = opencvImage.copy() + + interp = cv2.INTER_CUBIC + if zoomFactor*factor < 0: + interp = cv2.INTER_AREA + + zoomedSrc = cv2.resize(srcImage,(None,None),fx=zoomFactor*factor,fy=zoomFactor*factor,interpolation=interp) + zoomedSrc = zoomedSrc[zoomOrigin[1]:zoomOrigin[1]+max_size,zoomOrigin[0]:zoomOrigin[0]+max_size,:] + + lastZoomFactor = zoomFactor + lastZoomOrigin = zoomOrigin + while 1: + + if lastZoomFactor != zoomFactor or lastZoomOrigin != zoomOrigin: + interp = cv2.INTER_CUBIC + if zoomFactor*factor < 0: + interp = cv2.INTER_AREA + zoomedSrc = cv2.resize(srcImage,(None,None),fx=zoomFactor*factor,fy=zoomFactor*factor,interpolation=interp) + zoomedSrc = 
zoomedSrc[zoomOrigin[1]:zoomOrigin[1]+max_size,zoomOrigin[0]:zoomOrigin[0]+max_size,:] + zoomedSrc = cv2.copyMakeBorder(zoomedSrc, 0, max_size-zoomedSrc.shape[0], 0, max_size-zoomedSrc.shape[1], cv2.BORDER_CONSTANT) + + lastZoomFactor = zoomFactor + lastZoomOrigin = zoomOrigin + + foreground = np.zeros_like(zoomedSrc) + + for i,polyline in enumerate(polys): + if len(polyline)>0: + + segs = polyline[::] + + active=False + if len(polys[-1])>0 and i==len(polys)-1 and lastx is not None: + segs = polyline+[(lastx,lasty)] + active=True + + segs = np.array(segs) - np.array([(zoomOrigin[0]/zoomFactor,zoomOrigin[1]/zoomFactor)]) + segs = (np.array([segs])*zoomFactor).astype(int) + + if active: + cv2.fillPoly(foreground, (np.array(segs)) , ( 190, 107, 253), 0) + else: + cv2.fillPoly(foreground, (np.array(segs)) , (255, 255, 255), 0) + + if active: + for x,y in segs[0]: + cv2.circle(foreground, (int(x),int(y)), 5, (25,25,25), 3) + cv2.circle(foreground, (int(x),int(y)), 5, (255,255,255), 2) + + + foreground[foreground<1] = zoomedSrc[foreground<1] + combinedImage = cv2.addWeighted(zoomedSrc, 0.5, foreground, 0.5, 0) + + helpText='Q=Save, C=Reset, LeftClick=Add new point to polygon, Rightclick=Close polygon, MouseWheel=Zoom, MidDrag=Pan' + combinedImage = cv2.putText(combinedImage, helpText, (0,11), font, 0.4, (0,0,0), 2, cv2.LINE_AA) + combinedImage = cv2.putText(combinedImage, helpText, (0,11), font, 0.4, (255,255,255), 1, cv2.LINE_AA) + + cv2.imshow('MaskingWindow',combinedImage) + + try: + key = cv2.waitKey(1) + if key == ord('q'): + if len(polys[0])>0: + newmask = np.zeros_like(cv2.cvtColor( opencvMask.astype('uint8') ,cv2.COLOR_GRAY2BGR) ) + for i,polyline in enumerate(polys): + if len(polyline)>0: + segs = [(int(a/factor),int(b/factor)) for a,b in polyline] + cv2.fillPoly(newmask, np.array([segs]), (255,255,255), 0) + cv2.destroyWindow('MaskingWindow') + return Image.fromarray( cv2.cvtColor( newmask, cv2.COLOR_BGR2GRAY) ),polys + break + if key == ord('c'): + polys = [[]] + + except Exception as e: + print(e) + break + + cv2.destroyWindow('MaskingWindow') + return mask,polys + +if __name__ == '__main__': + img = Image.open('K:\\test2.png') + oldmask = Image.new('L',img.size,(0,)) + newmask,newPolys = display_mask_ui(img,oldmask,1024,[[]]) + + opencvImg = cv2.cvtColor( np.array(img) , cv2.COLOR_RGB2BGR) + opencvMask = cv2.cvtColor( np.array(newmask) , cv2.COLOR_GRAY2BGR) + + combinedImage = cv2.addWeighted(opencvImg, 0.5, opencvMask, 0.5, 0) + combinedImage = Image.fromarray( cv2.cvtColor( combinedImage , cv2.COLOR_BGR2RGB)) + + display_mask_ui(combinedImage,oldmask,1024,[[]]) + + + exit() + +import modules.scripts as scripts +import gradio as gr + +from modules.processing import Processed, process_images +from modules.shared import opts, cmd_opts, state + +class Script(scripts.Script): + + def title(self): + return "External Image Masking" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + initialSize = 1024 + + try: + import tkinter as tk + root = tk.Tk() + screen_width = int(root.winfo_screenwidth()) + screen_height = int(root.winfo_screenheight()) + print(screen_width,screen_height) + initialSize = min(screen_width,screen_height)-50 + print(initialSize) + except Exception as e: + print(e) + + max_size = gr.Slider(label="Masking preview size", minimum=512, maximum=initialSize*2, step=8, value=initialSize) + with gr.Row(): + ask_on_each_run = gr.Checkbox(label='Draw new mask on every run', value=False) + non_contigious_split 
= gr.Checkbox(label='Process non-contigious masks separately', value=False) + + return [max_size,ask_on_each_run,non_contigious_split] + + def run(self, p, max_size, ask_on_each_run, non_contigious_split): + + if not hasattr(self,'lastImg'): + self.lastImg = None + + if not hasattr(self,'lastMask'): + self.lastMask = None + + if not hasattr(self,'lastPolys'): + self.lastPolys = [[]] + + if ask_on_each_run or self.lastImg is None or self.lastImg != p.init_images[0]: + + if self.lastImg is None or self.lastImg != p.init_images[0]: + self.lastPolys = [[]] + + p.image_mask,self.lastPolys = display_mask_ui(p.init_images[0],p.image_mask,max_size,self.lastPolys) + self.lastImg = p.init_images[0] + if p.image_mask is not None: + self.lastMask = p.image_mask.copy() + elif hasattr(self,'lastMask') and self.lastMask is not None: + p.image_mask = self.lastMask.copy() + + if non_contigious_split: + maskImgArr = np.array(p.image_mask) + ret, markers = cv2.connectedComponents(maskImgArr) + markerCount = markers.max() + + if markerCount > 1: + tempimages = [] + tempMasks = [] + for maski in range(1,markerCount+1): + print('maski',maski) + maskSection = np.zeros_like(maskImgArr) + maskSection[markers==maski] = 255 + p.image_mask = Image.fromarray( maskSection.copy() ) + proc = process_images(p) + images = proc.images + tempimages.append(np.array(images[0])) + tempMasks.append(np.array(maskSection.copy())) + + finalImage = tempimages[0].copy() + + for outimg,outmask in zip(tempimages,tempMasks): + + resizeimg = cv2.resize(outimg, (finalImage.shape[0],finalImage.shape[1]) ) + resizedMask = cv2.resize(outmask, (finalImage.shape[0],finalImage.shape[1]) ) + + finalImage[resizedMask==255] = resizeimg[resizedMask==255] + images = [finalImage] + + + else: + proc = process_images(p) + images = proc.images + else: + proc = process_images(p) + images = proc.images + + proc.images = images + return proc diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index de921ea8..dd95e588 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -44,40 +44,16 @@ class Script(scripts.Script): def title(self): return "Prompt matrix" - def ui(self, is_img2img): - gr.HTML('
') - with gr.Row(): - with gr.Column(): - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', - value=False, elem_id=self.elem_id("put_at_start")) - with gr.Column(): - # Radio buttons for selecting the prompt between positive and negative - prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", - elem_id=self.elem_id("prompt_type"), value="positive") - with gr.Row(): - with gr.Column(): - different_seeds = gr.Checkbox( - label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) - with gr.Column(): - # Radio buttons for selecting the delimiter to use in the resulting prompt - variations_delimiter = gr.Radio(["comma", "space"], label="Select delimiter", elem_id=self.elem_id( - "variations_delimiter"), value="comma") - return [put_at_start, different_seeds, prompt_type, variations_delimiter] + def ui(self, is_img2img): + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) - def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter): + return [put_at_start, different_seeds] + + def run(self, p, put_at_start, different_seeds): modules.processing.fix_seed(p) - # Raise error if promp type is not positive or negative - if prompt_type not in ["positive", "negative"]: - raise ValueError(f"Unknown prompt type {prompt_type}") - # Raise error if variations delimiter is not comma or space - if variations_delimiter not in ["comma", "space"]: - raise ValueError(f"Unknown variations delimiter {variations_delimiter}") - prompt = p.prompt if prompt_type == "positive" else p.negative_prompt - original_prompt = prompt[0] if type(prompt) == list else prompt - positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt - - delimiter = ", " if variations_delimiter == "comma" else " " + original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt all_prompts = [] prompt_matrix_parts = original_prompt.split("|") @@ -90,19 +66,16 @@ class Script(scripts.Script): else: selected_prompts = [prompt_matrix_parts[0]] + selected_prompts - all_prompts.append(delimiter.join(selected_prompts)) + all_prompts.append(", ".join(selected_prompts)) p.n_iter = math.ceil(len(all_prompts) / p.batch_size) p.do_not_save_grid = True print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") - if prompt_type == "positive": - p.prompt = all_prompts - else: - p.negative_prompt = all_prompts + p.prompt = all_prompts p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] - p.prompt_for_display = positive_prompt + p.prompt_for_display = original_prompt processed = process_images(p) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 3122f6f6..3df40483 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -286,24 +286,23 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend print("Unexpected error: draw_xyz_grid failed to return even a single processed image") return Processed(p, []) - sub_grids = [None] * len(zs) + grids = [None] * len(zs) for i in range(len(zs)): start_index = i * len(xs) * len(ys) end_index = start_index + len(xs) * len(ys) grid = images.image_grid(image_cache[start_index:end_index], 
rows=len(ys)) if draw_legend: grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - sub_grids[i] = grid + + grids[i] = grid if include_sub_grids and len(zs) > 1: processed_result.images.insert(i+1, grid) - sub_grid_size = sub_grids[0].size - z_grid = images.image_grid(sub_grids, rows=1) - if draw_legend: - z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) - processed_result.images[0] = z_grid + original_grid_size = grids[0].size + grids = images.image_grid(grids, rows=1) + processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]]) - return processed_result, sub_grids + return processed_result class SharedSettingsStackHelper(object): @@ -577,7 +576,7 @@ class Script(scripts.Script): return res with SharedSettingsStackHelper(): - processed, sub_grids = draw_xyz_grid( + processed = draw_xyz_grid( p, xs=xs, ys=ys, @@ -593,10 +592,6 @@ class Script(scripts.Script): second_axes_processed=second_axes_processed ) - if opts.grid_save and len(sub_grids) > 1: - for sub_grid in sub_grids: - images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) - if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) diff --git a/special_start_medvram ohne xformers.bat b/special_start_medvram ohne xformers.bat new file mode 100644 index 00000000..9377ad45 --- /dev/null +++ b/special_start_medvram ohne xformers.bat @@ -0,0 +1,8 @@ +@echo off + +set PYTHON= +set GIT= +set VENV_DIR= +set COMMANDLINE_ARGS=--opt-split-attention-v1 --medvram --api + +call webui.bat diff --git a/special_start_medvram.bat b/special_start_medvram.bat new file mode 100644 index 00000000..b0803c41 --- /dev/null +++ b/special_start_medvram.bat @@ -0,0 +1,8 @@ +@echo off + +set PYTHON= +set GIT= +set VENV_DIR= +set COMMANDLINE_ARGS=--opt-split-attention --medvram --api --xformers + +call webui.bat diff --git a/webui-user.bat b/webui-user.bat index e5a257be..dacdf1b9 100644 --- a/webui-user.bat +++ b/webui-user.bat @@ -3,6 +3,7 @@ set PYTHON= set GIT= set VENV_DIR= -set COMMANDLINE_ARGS= +set COMMANDLINE_ARGS=--medvram --api +git pull call webui.bat
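
The split_weighted_subprompts helper added in optimizedSD/optimUtils.py walks the prompt left to right, cutting at each ':' and reading the number that follows as the weight of the preceding sub-prompt, defaulting to 1.0 when no number is given. A minimal usage sketch, assuming the repository root is on sys.path:

# Hypothetical usage of the parser from optimizedSD/optimUtils.py.
from optimizedSD.optimUtils import split_weighted_subprompts

prompts, weights = split_weighted_subprompts("a cat:1.5 watercolor")
# Tracing the parser, this should yield:
#   prompts == ["a cat", "watercolor"], weights == [1.5, 1.0]
print(prompts, weights)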
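
CompVisDenoiser in optimizedSD/samplers.py wraps the eps-predicting UNet in the k-diffusion convention: the discrete alphas_cumprod schedule is mapped to continuous noise levels sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod), the input is scaled by c_in = 1 / sqrt(sigma^2 + 1), and the prediction is turned into a denoised image via x + c_out * eps with c_out = -sigma. A small standalone sketch of those mappings, with illustrative schedule values only:

import torch

# Illustrative schedule; a real checkpoint supplies its own alphas_cumprod.
alphas_cumprod = torch.linspace(0.999, 0.01, 10)

# DiscreteEpsDDPMDenoiser: sigma = sqrt((1 - alpha_bar) / alpha_bar)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

# get_scalings() with sigma_data = 1
c_out = -sigmas
c_in = 1 / (sigmas ** 2 + 1) ** 0.5

# forward(): denoised = x + eps * c_out, where eps = model(x * c_in, sigma_to_t(sigma))
print(sigmas[:3], c_in[:3], c_out[:3])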
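
get_ancestral_step splits a move from sigma_from to sigma_to into a deterministic step down to sigma_down plus fresh noise of size sigma_up, chosen so the two contributions recombine to the target level: sigma_down^2 + sigma_up^2 = sigma_to^2. A quick numeric check of that identity:

import math

def get_ancestral_step(sigma_from, sigma_to):
    # As defined in optimizedSD/samplers.py
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

sigma_down, sigma_up = get_ancestral_step(10.0, 7.0)
assert math.isclose(sigma_down ** 2 + sigma_up ** 2, 7.0 ** 2)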
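
The CrossAttention module in optimizedSD/splitAttention.py lowers peak memory by computing the attention matrix in chunks along the (batch*heads) dimension instead of all at once; because softmax and the value product act row-wise, the chunked result matches the full computation. A self-contained sketch of that equivalence (not the repo's class, just the idea it relies on):

import torch

def full_attention(q, k, v, scale):
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    return torch.einsum('b i j, b j d -> b i d', sim.softmax(dim=-1), v)

def chunked_attention(q, k, v, scale, att_step=1):
    # Process att_step rows of the (batch*heads) dimension at a time,
    # so only a chunk-sized similarity matrix is ever materialised.
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2])
    for i in range(0, q.shape[0], att_step):
        sim = torch.einsum('b i d, b j d -> b i j',
                           q[i:i + att_step], k[i:i + att_step]) * scale
        out[i:i + att_step] = torch.einsum('b i j, b j d -> b i d',
                                           sim.softmax(dim=-1), v[i:i + att_step])
    return out

# (batch*heads, tokens, dim_head); illustrative sizes only.
q, k, v = (torch.randn(8, 64, 40) for _ in range(3))
scale = 40 ** -0.5
assert torch.allclose(full_attention(q, k, v, scale),
                      chunked_attention(q, k, v, scale, att_step=2), atol=1e-6)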