""" This code started out as a PyTorch port of Ho et al's diffusion models: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. """ import enum import math import random import numpy as np import torch import torch as th from torch.distributions import Normal from tqdm import tqdm from models.diffusion.nn import mean_flat from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True): """ Remaps [t] from a batch of integers into a causal sequence [S] long where each sequence element is [causal_slope] timesteps advanced from the previous sequence element. At t=0, the sequence is all 0s and at t=[num_timesteps], the sequence is all [num_timesteps]. As a result of the last property, longer sequences will have larger "gaps" between them in continuous space. This must be considered at inference time. Specifically, you should allot ((num_timesteps+causal_slope*(seq_len-1))/num_timesteps) times more timesteps in inference for the same quality. :param t: Batched timestep integers :param S: Sequence length. :param num_timesteps: Number of total timesteps. :param causal_slope: The causal slope. Ex: "2" means each sequence element will be 2 timesteps ahead of its predecessor. :param add_jitter: Whether or not to add random jitter into the extra gaps between timesteps added by this function. Should be true for training and false for inference. :return: [b,S] sequence of timestep integers. """ S_sloped = causal_slope * (S-1) # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. 
To make this # actually work, we map the existing t from the timescale specified to the model to the causal timescale: adj_t = torch.div(t * (num_timesteps + S_sloped), num_timesteps, rounding_mode='floor') adj_t = adj_t - S_sloped if add_jitter: t_gap = (num_timesteps + S_sloped) / num_timesteps jitter = (2*random.random()-1) * t_gap adj_t = (adj_t+jitter).clamp(-S_sloped, num_timesteps) # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope. t = adj_t.unsqueeze(1).repeat(1, S) t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(-1, num_timesteps).long() return t def causal_mask_and_fix(t, num_timesteps): mask1 = t == num_timesteps t[mask1] = num_timesteps-1 mask2 = t == -1 t[mask2] = 0 return t, mask1.logical_or(mask2) def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): """ Get a pre-defined beta schedule for the given name. The beta schedule library consists of beta schedules which remain similar in the limit of num_diffusion_timesteps. Beta schedules may be added, but should not be removed or changed once they are committed to maintain backwards compatibility. """ if schedule_name == "linear": # Linear schedule from Ho et al, extended to work for any number of # diffusion steps. scale = 1000 / num_diffusion_timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 return np.linspace( beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 ) elif schedule_name == "cosine": return betas_for_alpha_bar( num_diffusion_timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) else: raise NotImplementedError(f"unknown beta schedule: {schedule_name}") def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. :param num_diffusion_timesteps: the number of betas to produce. 
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up to that part of the diffusion process. :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. """ betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return np.array(betas) class ModelMeanType(enum.Enum): """ Which type of output the model predicts. """ PREVIOUS_X = 'previous_x' # the model predicts x_{t-1} START_X = 'start_x' # the model predicts x_0 EPSILON = 'epsilon' # the model predicts epsilon class ModelVarType(enum.Enum): """ What is used as the model's output variance. The LEARNED_RANGE option has been added to allow the model to predict values between FIXED_SMALL and FIXED_LARGE, making its job easier. """ LEARNED = 'learned' FIXED_SMALL = 'fixed_small' FIXED_LARGE = 'fixed_large' LEARNED_RANGE = 'learned_range' class LossType(enum.Enum): MSE = 'mse' # use raw MSE loss (and KL when learning variances) RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances) KL = 'kl' # use the variational lower-bound RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB def is_vb(self): return self == LossType.KL or self == LossType.RESCALED_KL class GaussianDiffusion: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. 
:param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, conditioning_free=False, conditioning_free_k=1, ramp_conditioning_free=True, ): self.model_mean_type = ModelMeanType(model_mean_type) self.model_var_type = ModelVarType(model_var_type) self.loss_type = LossType(loss_type) self.rescale_timesteps = rescale_timesteps self.conditioning_free = conditioning_free self.conditioning_free_k = conditioning_free_k self.ramp_conditioning_free = ramp_conditioning_free # Use float64 for accuracy. betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. 
self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None, allow_negatives=False): """ Diffuse the data for a given number of diffusion steps. In other words, sample from q(x_t | x_0). :param x_start: the initial data batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. 
""" if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape if allow_negatives: mask = (t < 0) t[mask] = 0 result = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) if allow_negatives: result[mask] = x_start[mask] return result def q_posterior_mean_variance(self, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. 
""" if model_kwargs is None: model_kwargs = {} B, C = x.shape[:2] assert t.shape == (B,) or t.shape == (B,1,x.shape[-1]) model_output = model(x, self._scale_timesteps(t), **model_kwargs) if self.conditioning_free: model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: assert model_output.shape == (B, C * 2, *x.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) if self.conditioning_free: model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) if self.model_var_type == ModelVarType.LEARNED: model_log_variance = model_var_values model_variance = th.exp(model_log_variance) else: min_log = _extract_into_tensor( self.posterior_log_variance_clipped, t, x.shape ) max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. frac = (model_var_values + 1) / 2 model_log_variance = frac * max_log + (1 - frac) * min_log model_variance = th.exp(model_log_variance) else: model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so # to get a better decoder log likelihood. ModelVarType.FIXED_LARGE: ( np.append(self.posterior_variance[1], self.betas[1:]), np.log(np.append(self.posterior_variance[1], self.betas[1:])), ), ModelVarType.FIXED_SMALL: ( self.posterior_variance, self.posterior_log_variance_clipped, ), }[self.model_var_type] model_variance = _extract_into_tensor(model_variance, t, x.shape) model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) if self.conditioning_free: if self.ramp_conditioning_free: assert t.shape[0] == 1 # This should only be used in inference. 
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t).float().mean().item() / self.num_timesteps) else: cfk = self.conditioning_free_k model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning # TODO: combine variance predictions here similarly. def process_xstart(x): if denoised_fn is not None: x = denoised_fn(x) if clip_denoised: return x.clamp(-1, 1) return x if self.model_mean_type == ModelMeanType.PREVIOUS_X: assert 'why are you doing this?' pred_xstart = process_xstart( self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) ) model_mean = model_output elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: if self.model_mean_type == ModelMeanType.START_X: assert 'bad boy.' pred_xstart = process_xstart(model_output) else: eps = model_output pred_xstart = process_xstart( self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) ) model_mean, _, _ = self.q_posterior_mean_variance( x_start=pred_xstart, x_t=x, t=t ) else: raise NotImplementedError(self.model_mean_type) assert ( model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape ) return { "mean": model_mean, "variance": model_variance, "log_variance": model_log_variance, "pred_xstart": pred_xstart, "pred_eps": eps, } def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor( self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape ) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / 
_extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: return t.float() * (1000.0 / self.num_timesteps) return t def _get_scale_ratio(self): return 1 def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). """ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def p_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 
:param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) if len(t.shape) == 1: nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 else: nonzero_mask = (t != 0).float() if cond_fn is not None: out["mean"] = self.condition_mean( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"], "pred_eps": out["pred_eps"], "mean": out["mean"], "log_variance": out["log_variance"]} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, causal=False, causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, ): """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. 
If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. :return: a non-differentiable batch of samples. """ final = None for sample in self.p_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, causal=causal, causal_slope=causal_slope, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, ): final = sample return final["sample"] def p_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, causal=False, causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, ): """ Generate samples from the model and yield intermediate samples from each timestep of diffusion. Arguments are the same as p_sample_loop(). Returns a generator over dicts, where each dict is the return value of p_sample(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] orig_img = img for i in tqdm(indices): t = th.tensor([i] * shape[0], device=device) mask = torch.zeros_like(img) if causal: t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1) t, mask = causal_mask_and_fix(t, self.num_timesteps) mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): out = self.p_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, ) yield out img = out["sample"] if torch.any(mask): img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. 
orig_img = img def p_sample_loop_with_guidance( self, model, guidance_input, mask, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, ): if device is None: device = next(model.parameters()).device shape = guidance_input.shape if noise is None: noise = th.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] img = noise for i in tqdm(indices): t = th.tensor([i] * shape[0], device=device) with th.no_grad(): out = self.p_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, ) model_driven_out = out["sample"] * mask.logical_not() guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask img = model_driven_out + guidance_driven_out return img def p_sample_loop_for_log_perplexity( self, model, truth, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, ): if device is None: device = next(model.parameters()).device shape = truth.shape if noise is None: noise = th.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] img = noise #perp = self.num_timesteps logperp = 0 for i in tqdm(indices): t = th.tensor([i] * shape[0], device=device) with th.no_grad(): out = self.p_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, ) eps = out["pred_eps"] err = noise - eps m = Normal(torch.tensor([0.0]), torch.tensor([1.0])) nprobs = m.cdf(-err.abs().cpu()) * 2 logperp = torch.log(nprobs) / self.num_timesteps + logperp #perp = nprobs * perp print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this separately. logperp[torch.isinf(logperp)] = logperp.max() * 2 return -logperp def ddim_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). 
""" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) if len(t.shape) == 2: nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 else: nonzero_mask = (t != 0).float() sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. 
reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, causal=False, causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=True, eta=0.0, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, causal=causal, causal_slope=causal_slope, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, ): final = sample return final["sample"] def ddim_sample_loop_with_guidance( self, model, guidance_input, mask, noise=None, clip_denoised=True, causal=False, causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): device = guidance_input.device shape = guidance_input.shape if noise is not None: img = noise else: img = th.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] orig_img = img for i in tqdm(indices): t = th.tensor([i] * shape[0], device=device) c_mask = torch.zeros_like(img) if causal: t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1) t, c_mask = causal_mask_and_fix(t, self.num_timesteps) t[c_mask] = self.num_timesteps-1 c_mask = c_mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): out = self.ddim_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) model_driven_out = out["sample"] * mask.logical_not() if torch.any(c_mask): model_driven_out[c_mask] = orig_img[c_mask] # For causal diffusion, keep resetting these predictions until they are unmasked. 
guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask img = model_driven_out + guidance_driven_out orig_img = orig_img return img def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, causal=False, causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=True, eta=0.0, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) orig_img = img for i in indices: t = th.tensor([i] * shape[0], device=device) mask = torch.zeros_like(img) if causal: t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1) t, mask = causal_mask_and_fix(t, self.num_timesteps) t[mask] = self.num_timesteps-1 mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): out = self.ddim_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] if torch.any(mask): img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. orig_img = orig_img def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. 
""" true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = kl / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = decoder_nll / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) if len(t.shape) == 1: output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl) else: output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None): """ Compute training losses for a causal diffusion process. """ assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis." ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=True) ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start. return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn) def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None): """ Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. 
Some mean or variance settings may also have other keys. """ if model_kwargs is None: model_kwargs = {} if noise is None: noise = th.randn_like(x_start) if len(t.shape) == 3: t, t_mask = causal_mask_and_fix(t, self.num_timesteps) t_mask = t_mask.logical_not() # This is used to mask out losses for timesteps that are out of bounds. else: t_mask = torch.ones_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: # TODO: support multiple model outputs for this mode. terms["loss"] = mean_flat(self._vb_terms_bpd( model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"]) if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) if isinstance(model_outputs, tuple): model_output = model_outputs[0] terms['extra_outputs'] = model_outputs[1:] else: model_output = model_outputs if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. 
terms["vb"] *= self.num_timesteps / 1000.0 if self.model_mean_type == ModelMeanType.PREVIOUS_X: target = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0] x_start_pred = torch.zeros(x_start) # Not supported. elif self.model_mean_type == ModelMeanType.START_X: target = x_start x_start_pred = model_output elif self.model_mean_type == ModelMeanType.EPSILON: target = noise x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) else: raise NotImplementedError(self.model_mean_type) assert model_output.shape == target.shape == x_start.shape s_err = t_mask * (target - model_output) ** 2 if channel_balancing_fn is not None: s_err = channel_balancing_fn(s_err) terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1) terms["mse"] = mean_flat(s_err) terms["vb"] = terms["vb"] * t_mask terms["x_start_predicted"] = x_start_pred if "vb" in terms: if channel_balancing_fn is not None: terms["vb"] = channel_balancing_fn(terms["vb"]) terms["loss"] = terms["mse"] + mean_flat(terms["vb"]) else: terms["loss"] = terms["mse"] else: raise NotImplementedError(self.loss_type) return terms def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] 
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a float32 tensor expanded to broadcast_shape, on the same device
             as timesteps.
    """
    # Cast to float32 *before* the device transfer: the result is float32
    # either way, and this avoids moving a float64 table onto backends that
    # cannot hold float64 tensors (e.g. Apple MPS). Values are identical to
    # indexing first and casting afterwards.
    table = th.from_numpy(arr).float().to(device=timesteps.device)
    res = table[timesteps]
    # Append trailing singleton dims until the rank matches broadcast_shape.
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
""" res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape) def test_causal_training_losses(): from models.diffusion.respace import SpacedDiffusion from models.diffusion.respace import space_timesteps diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), conditioning_free=False, conditioning_free_k=1) class IdentityTwoArg(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, *args, **kwargs): return x.repeat(1,2,1) model = IdentityTwoArg() diff.causal_training_losses(model, torch.randn(4,256,400), torch.tensor([500,1000,3000,3500]), causal_slope=4) def graph_causal_timestep_adjustment(): import matplotlib.pyplot as plt S = 400 #slope=4 num_timesteps=4000 for slpe in range(10, 400, 50): slope = slpe / 10 t_res = [] for t in range(num_timesteps, -1, -num_timesteps//50): T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0] # The following adjustment makes it easier to visualize the timestep regions where the model is actually working. 
def graph_causal_timestep_adjustment_by_timestep():
    """
    Plot the causal timestep sequence for each input timestep separately.

    Uses a fixed causal slope of 8, a sequence length of 400 and 4000 total
    timesteps. For each of 51 evenly spaced timesteps (4000 down to 0) the
    timestep is mapped through `causal_timestep_adjustment` with jitter
    disabled, and the resulting sequence is saved as '<t>.png' in the current
    working directory.
    """
    import matplotlib.pyplot as plt

    S = 400
    slope=8
    num_timesteps=4000
    # NOTE(review): the figure is saved and cleared once per timestep, as the
    # per-`t` file name implies -- confirm against the original layout.
    for t in range(num_timesteps, -1, -num_timesteps//50):
        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
        plt.plot(T.numpy())
        plt.ylim(0,num_timesteps)
        # NOTE(review): the x-limit is a literal 4000 although each plotted
        # curve has S points -- confirm this is intended.
        plt.xlim(0,4000)
        plt.ylabel('timestep')
        plt.savefig(f'{t}.png')
        plt.clf()


if __name__ == '__main__':
    #test_causal_training_losses()
    #graph_causal_timestep_adjustment()
    graph_causal_timestep_adjustment_by_timestep()