# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/diffusion.py

import enum
import math
import random
from tqdm import tqdm
from abc import abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm.auto import tqdm

from torch import autocast

from .arch_utils import normalization, AttentionBlock


"""
This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, whose own header reads:

This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""



def normal_kl(mean1, logvar1, mean2, logvar2):
	"""
	Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

	All four arguments broadcast against each other, so tensors may be mixed
	with Python scalars. At least one argument must be a torch.Tensor; any
	non-tensor log-variance is turned into a tensor with that argument's dtype
	and device, because torch.exp() does not accept plain Python numbers.
	"""
	reference = next(
		(arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
		None,
	)
	assert reference is not None, "at least one argument must be a Tensor"

	def _as_tensor(value):
		# Leave tensors alone; lift scalars onto the reference tensor's dtype/device.
		if isinstance(value, torch.Tensor):
			return value
		return torch.tensor(value).to(reference)

	logvar1 = _as_tensor(logvar1)
	logvar2 = _as_tensor(logvar2)

	variance_ratio = torch.exp(logvar1 - logvar2)
	mean_term = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
	return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + mean_term)


def approx_standard_normal_cdf(x):
	"""
	Cheap tanh-based approximation of the standard normal CDF, Phi(x).

	Evaluates 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
	elementwise on the tensor ``x``.
	"""
	cubic_term = x + 0.044715 * torch.pow(x, 3)
	scale = np.sqrt(2.0 / np.pi)
	return 0.5 * (1.0 + torch.tanh(scale * cubic_term))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
	"""
	Log-likelihood of ``x`` under a Gaussian discretized into 8-bit bins.

	Each value is assigned the probability mass of a bin of half-width 1/255
	around it. Values below -0.999 or above 0.999 are treated as edge bins
	that absorb the whole lower / upper tail.

	:param x: the target images, assumed to be uint8 values rescaled to [-1, 1].
	:param means: the Gaussian mean Tensor (same shape as x).
	:param log_scales: the Gaussian log standard deviation Tensor (same shape as x).
	:return: a tensor shaped like x holding log probabilities in nats.
	"""
	assert x.shape == means.shape == log_scales.shape
	bin_half_width = 1.0 / 255.0
	offset = x - means
	inv_std = torch.exp(-log_scales)

	upper_cdf = approx_standard_normal_cdf(inv_std * (offset + bin_half_width))
	lower_cdf = approx_standard_normal_cdf(inv_std * (offset - bin_half_width))

	# Clamp before taking logs so empty bins give a large-but-finite penalty.
	floor = 1e-12
	log_upper = torch.log(upper_cdf.clamp(min=floor))
	log_one_minus_lower = torch.log((1.0 - lower_cdf).clamp(min=floor))
	log_bin_mass = torch.log((upper_cdf - lower_cdf).clamp(min=floor))

	log_probs = torch.where(
		x < -0.999,
		log_upper,
		torch.where(x > 0.999, log_one_minus_lower, log_bin_mass),
	)
	assert log_probs.shape == x.shape
	return log_probs


def mean_flat(tensor):
	"""
	Average ``tensor`` over every dimension except dimension 0 (the batch).
	"""
	non_batch_dims = list(range(1, tensor.dim()))
	return tensor.mean(dim=non_batch_dims)


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
	"""
	Return the named beta schedule with ``num_diffusion_timesteps`` entries.

	Supported names are "linear" (Ho et al.'s schedule, rescaled so it stays
	comparable for any number of steps) and "cosine". Schedules should keep
	similar behaviour as the step count grows, and existing ones must not be
	changed once committed, for backwards compatibility.

	:raises NotImplementedError: for an unknown ``schedule_name``.
	"""
	if schedule_name == "cosine":
		return betas_for_alpha_bar(
			num_diffusion_timesteps,
			lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
		)
	if schedule_name != "linear":
		raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

	# The reference schedule is defined for 1000 steps; stretch it for other counts.
	scale = 1000 / num_diffusion_timesteps
	return np.linspace(
		scale * 0.0001, scale * 0.02, num_diffusion_timesteps, dtype=np.float64
	)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
	"""
	Discretize a continuous alpha_bar(t) curve, t in [0, 1], into betas.

	Step i covers t in [i / n, (i + 1) / n], and its beta is
	1 - alpha_bar(t_end) / alpha_bar(t_start), capped at ``max_beta``.

	:param num_diffusion_timesteps: n, the number of betas to produce.
	:param alpha_bar: callable mapping t in [0, 1] to the cumulative product
		of (1 - beta) up to that point of the diffusion process.
	:param max_beta: upper bound on every beta; keep it below 1 to avoid
		singularities.
	:return: a numpy array of n betas.
	"""
	n = num_diffusion_timesteps
	return np.array([
		min(1 - alpha_bar((step + 1) / n) / alpha_bar(step / n), max_beta)
		for step in range(n)
	])


class ModelMeanType(enum.Enum):
	"""
	What quantity the network's (mean) output represents.
	"""

	# The model predicts x_{t-1} directly.
	PREVIOUS_X = "previous_x"
	# The model predicts the clean sample x_0.
	START_X = "start_x"
	# The model predicts the added noise epsilon.
	EPSILON = "epsilon"


class ModelVarType(enum.Enum):
	"""
	Where the reverse-process variance comes from.

	LEARNED_RANGE lets the model output a value in [-1, 1] that interpolates
	between the FIXED_SMALL and FIXED_LARGE variances, which is an easier
	target than an unconstrained log-variance.
	"""

	# The model outputs the log-variance directly.
	LEARNED = "learned"
	# Use the posterior variance of q(x_{t-1} | x_t, x_0).
	FIXED_SMALL = "fixed_small"
	# Use beta_t (with the first step replaced by the posterior variance).
	FIXED_LARGE = "fixed_large"
	# The model outputs an interpolation weight between the two fixed choices.
	LEARNED_RANGE = "learned_range"


class LossType(enum.Enum):
	"""
	Training objectives understood by GaussianDiffusion.
	"""

	# Raw MSE loss (plus the KL term when variances are learned).
	MSE = "mse"
	# Raw MSE loss (plus the RESCALED_KL term when variances are learned).
	RESCALED_MSE = "rescaled_mse"
	# The variational lower bound.
	KL = "kl"
	# Like KL, but rescaled to estimate the full VLB.
	RESCALED_KL = "rescaled_kl"

	def is_vb(self):
		"""Return True if this objective is a variational-bound (KL) loss."""
		return self in (LossType.KL, LossType.RESCALED_KL)


class GaussianDiffusion:
	"""
	Utilities for training and sampling diffusion models.

	Ported directly from here, and then adapted over time to further experimentation.
	https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

	:param betas: a 1-D array-like of betas, one per diffusion timestep; index 0
		is the first (least noisy) step.
	:param model_mean_type: a ModelMeanType, or its string value, determining
		what the model outputs.
	:param model_var_type: a ModelVarType, or its string value, determining how
		variance is output.
	:param loss_type: a LossType, or its string value, determining the loss
		function to use.
	:param rescale_timesteps: if True, pass floating point timesteps into the
		model so that they are always scaled like in the original paper
		(0 to 1000).
	:param conditioning_free: if True, p_mean_variance() evaluates the model a
		second time with conditioning_free=True and extrapolates away from that
		unconditioned output (classifier-free-guidance style).
	:param conditioning_free_k: strength of that extrapolation.
	:param ramp_conditioning_free: if True, multiply conditioning_free_k by
		(1 - scaled_t / num_timesteps); only valid for batch size 1.
	"""

	def __init__(
		self,
		*,
		betas,
		model_mean_type,
		model_var_type,
		loss_type,
		rescale_timesteps=False,
		conditioning_free=False,
		conditioning_free_k=1,
		ramp_conditioning_free=True,
	):
		self.model_mean_type = ModelMeanType(model_mean_type)
		self.model_var_type = ModelVarType(model_var_type)
		self.loss_type = LossType(loss_type)
		self.rescale_timesteps = rescale_timesteps
		self.conditioning_free = conditioning_free
		self.conditioning_free_k = conditioning_free_k
		self.ramp_conditioning_free = ramp_conditioning_free

		# Use float64 for accuracy.
		betas = np.array(betas, dtype=np.float64)
		self.betas = betas
		assert len(betas.shape) == 1, "betas must be 1-D"
		assert (betas > 0).all() and (betas <= 1).all()

		self.num_timesteps = int(betas.shape[0])

		alphas = 1.0 - betas
		self.alphas_cumprod = np.cumprod(alphas, axis=0)
		self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
		self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
		assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

		# calculations for diffusion q(x_t | x_{t-1}) and others
		self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
		self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
		self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
		self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
		self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

		# calculations for posterior q(x_{t-1} | x_t, x_0)
		self.posterior_variance = (
			betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
		)
		# log calculation clipped because the posterior variance is 0 at the
		# beginning of the diffusion chain: step 0 reuses step 1's value.
		self.posterior_log_variance_clipped = np.log(
			np.append(self.posterior_variance[1], self.posterior_variance[1:])
		)
		self.posterior_mean_coef1 = (
			betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
		)
		self.posterior_mean_coef2 = (
			(1.0 - self.alphas_cumprod_prev)
			* np.sqrt(alphas)
			/ (1.0 - self.alphas_cumprod)
		)

	def q_mean_variance(self, x_start, t):
		"""
		Get the distribution q(x_t | x_0).

		:param x_start: the [N x C x ...] tensor of noiseless inputs.
		:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
		:return: A tuple (mean, variance, log_variance), all of x_start's shape.
		"""
		mean = (
			_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
		)
		variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
		log_variance = _extract_into_tensor(
			self.log_one_minus_alphas_cumprod, t, x_start.shape
		)
		return mean, variance, log_variance

	def q_sample(self, x_start, t, noise=None):
		"""
		Diffuse the data for a given number of diffusion steps.

		In other words, sample from q(x_t | x_0).

		:param x_start: the initial data batch.
		:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
		:param noise: if specified, the split-out normal noise (same shape as
			x_start); otherwise fresh standard normal noise is drawn.
		:return: A noisy version of x_start.
		"""
		if noise is None:
			noise = torch.randn_like(x_start)
		assert noise.shape == x_start.shape
		return (
			_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
			+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
			* noise
		)

	def q_posterior_mean_variance(self, x_start, x_t, t):
		"""
		Compute the mean and variance of the diffusion posterior:

			q(x_{t-1} | x_t, x_0)

		:return: a tuple (posterior_mean, posterior_variance,
			posterior_log_variance_clipped).
		"""
		assert x_start.shape == x_t.shape
		posterior_mean = (
			_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
			+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
		)
		posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
		posterior_log_variance_clipped = _extract_into_tensor(
			self.posterior_log_variance_clipped, t, x_t.shape
		)
		assert (
			posterior_mean.shape[0]
			== posterior_variance.shape[0]
			== posterior_log_variance_clipped.shape[0]
			== x_start.shape[0]
		)
		return posterior_mean, posterior_variance, posterior_log_variance_clipped

	def p_mean_variance(
		self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
	):
		"""
		Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
		the initial x, x_0.

		:param model: the model, which takes a signal and a batch of timesteps
			as input.
		:param x: the [N x C x ...] tensor at time t.
		:param t: a 1-D Tensor of timesteps.
		:param clip_denoised: if True, clip the denoised signal into [-1, 1].
		:param denoised_fn: if not None, a function which applies to the
			x_start prediction before it is used to sample. Applies before
			clip_denoised.
		:param model_kwargs: if not None, a dict of extra keyword arguments to
			pass to the model. This can be used for conditioning.
		:return: a dict with the following keys:
				 - 'mean': the model mean output.
				 - 'variance': the model variance output.
				 - 'log_variance': the log of 'variance'.
				 - 'pred_xstart': the prediction for x_0.
		"""
		if model_kwargs is None:
			model_kwargs = {}

		B, C = x.shape[:2]
		assert t.shape == (B,)
		model_output = model(x, self._scale_timesteps(t), **model_kwargs)
		if self.conditioning_free:
			# Second, unconditioned evaluation used for the guidance blend below.
			model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)

		if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
			# Learned variances: the model emits 2*C channels (mean half, variance half).
			assert model_output.shape == (B, C * 2, *x.shape[2:])
			model_output, model_var_values = torch.split(model_output, C, dim=1)
			if self.conditioning_free:
				# Only the mean half of the unconditioned pass is used; the
				# variance always comes from the conditioned pass.
				model_output_no_conditioning, _ = torch.split(model_output_no_conditioning, C, dim=1)
			if self.model_var_type == ModelVarType.LEARNED:
				model_log_variance = model_var_values
				model_variance = torch.exp(model_log_variance)
			else:
				min_log = _extract_into_tensor(
					self.posterior_log_variance_clipped, t, x.shape
				)
				max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
				# The model_var_values is [-1, 1] for [min_var, max_var].
				frac = (model_var_values + 1) / 2
				model_log_variance = frac * max_log + (1 - frac) * min_log
				model_variance = torch.exp(model_log_variance)
		else:
			# Fixed variances: only compute the table that is actually selected.
			if self.model_var_type == ModelVarType.FIXED_LARGE:
				# for fixedlarge, we set the initial (log-)variance like so
				# to get a better decoder log likelihood.
				fixed_variance = np.append(self.posterior_variance[1], self.betas[1:])
				fixed_log_variance = np.log(fixed_variance)
			else:  # ModelVarType.FIXED_SMALL
				fixed_variance = self.posterior_variance
				fixed_log_variance = self.posterior_log_variance_clipped
			model_variance = _extract_into_tensor(fixed_variance, t, x.shape)
			model_log_variance = _extract_into_tensor(fixed_log_variance, t, x.shape)

		if self.conditioning_free:
			if self.ramp_conditioning_free:
				assert t.shape[0] == 1  # This should only be used in inference.
				# NOTE(review): with rescale_timesteps=True the scaled t lies in
				# [0, 1000) while num_timesteps may differ; confirm the ramp is
				# intended to be computed on the scaled value.
				cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
			else:
				cfk = self.conditioning_free_k
			# Extrapolate from the unconditioned output towards the conditioned one.
			model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning

		def process_xstart(x):
			if denoised_fn is not None:
				x = denoised_fn(x)
			if clip_denoised:
				return x.clamp(-1, 1)
			return x

		if self.model_mean_type == ModelMeanType.PREVIOUS_X:
			pred_xstart = process_xstart(
				self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
			)
			model_mean = model_output
		elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
			if self.model_mean_type == ModelMeanType.START_X:
				pred_xstart = process_xstart(model_output)
			else:
				pred_xstart = process_xstart(
					self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
				)
			model_mean, _, _ = self.q_posterior_mean_variance(
				x_start=pred_xstart, x_t=x, t=t
			)
		else:
			raise NotImplementedError(self.model_mean_type)

		assert (
			model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
		)
		return {
			"mean": model_mean,
			"variance": model_variance,
			"log_variance": model_log_variance,
			"pred_xstart": pred_xstart,
		}

	def _predict_xstart_from_eps(self, x_t, t, eps):
		"""Recover x_0 from x_t and a noise prediction eps."""
		assert x_t.shape == eps.shape
		return (
			_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
			- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
		)

	def _predict_xstart_from_xprev(self, x_t, t, xprev):
		"""Recover x_0 by inverting the posterior mean formula given x_{t-1}."""
		assert x_t.shape == xprev.shape
		return (  # (xprev - coef2*x_t) / coef1
			_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
			- _extract_into_tensor(
				self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
			)
			* x_t
		)

	def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
		"""Recover the noise eps from x_t and an x_0 prediction."""
		return (
			_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
			- pred_xstart
		) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

	def _scale_timesteps(self, t):
		"""
		Map step indices to the values fed to the model: floats rescaled to a
		1000-step range when rescale_timesteps is set, otherwise ``t`` unchanged.
		"""
		if self.rescale_timesteps:
			return t.float() * (1000.0 / self.num_timesteps)
		return t

	def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
		"""
		Compute the mean for the previous step, given a function cond_fn that
		computes the gradient of a conditional log probability with respect to
		x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
		condition on y.

		This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

		:param model_kwargs: extra keyword arguments for cond_fn; None means none.
		:return: the shifted mean, as a float tensor.
		"""
		# p_sample() forwards model_kwargs=None by default; ** needs a mapping.
		if model_kwargs is None:
			model_kwargs = {}
		gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
		new_mean = (
			p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
		)
		return new_mean

	def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
		"""
		Compute what the p_mean_variance output would have been, should the
		model's score function be conditioned by cond_fn.

		See condition_mean() for details on cond_fn.

		Unlike condition_mean(), this instead uses the conditioning strategy
		from Song et al (2020).

		:param model_kwargs: extra keyword arguments for cond_fn; None means none.
		:return: a shallow copy of p_mean_var with 'pred_xstart' and 'mean' updated.
		"""
		# ddim_sample() forwards model_kwargs=None by default; ** needs a mapping.
		if model_kwargs is None:
			model_kwargs = {}
		alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

		eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
		eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
			x, self._scale_timesteps(t), **model_kwargs
		)

		out = p_mean_var.copy()
		out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
		out["mean"], _, _ = self.q_posterior_mean_variance(
			x_start=out["pred_xstart"], x_t=x, t=t
		)
		return out

	def sample_loop(self,  *args, **kwargs):
		"""
		Dispatch to a sampling loop chosen by the optional ``sampler`` kwarg
		(case-insensitive): 'p' -> p_sample_loop(), 'ddim' (default) ->
		ddim_sample_loop(). Remaining arguments are forwarded unchanged.

		:raises RuntimeError: for any other sampler name.
		"""
		sampler = kwargs.pop("sampler", "ddim").lower()
		if sampler == 'p':
			return self.p_sample_loop(*args, **kwargs)
		if sampler == 'ddim':
			return self.ddim_sample_loop(*args, **kwargs)

		raise RuntimeError(f"Sampler not implemented: {sampler}")

	def p_sample(
		self,
		model,
		x,
		t,
		clip_denoised=True,
		denoised_fn=None,
		cond_fn=None,
		model_kwargs=None,
	):
		"""
		Sample x_{t-1} from the model at the given timestep.

		:param model: the model to sample from.
		:param x: the current tensor at x_t.
		:param t: the value of t, starting at 0 for the first diffusion step.
		:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
		:param denoised_fn: if not None, a function which applies to the
			x_start prediction before it is used to sample.
		:param cond_fn: if not None, this is a gradient function that acts
			similarly to the model.
		:param model_kwargs: if not None, a dict of extra keyword arguments to
			pass to the model. This can be used for conditioning.
		:return: a dict containing the following keys:
				 - 'sample': a random sample from the model.
				 - 'pred_xstart': a prediction of x_0.
		"""
		out = self.p_mean_variance(
			model,
			x,
			t,
			clip_denoised=clip_denoised,
			denoised_fn=denoised_fn,
			model_kwargs=model_kwargs,
		)
		noise = torch.randn_like(x)
		nonzero_mask = (
			(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
		)  # no noise when t == 0
		if cond_fn is not None:
			out["mean"] = self.condition_mean(
				cond_fn, out, x, t, model_kwargs=model_kwargs
			)
		sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
		return {"sample": sample, "pred_xstart": out["pred_xstart"]}

	def p_sample_loop(
		self,
		model,
		shape,
		noise=None,
		clip_denoised=True,
		denoised_fn=None,
		cond_fn=None,
		model_kwargs=None,
		device=None,
		progress=False,
	):
		"""
		Generate samples from the model.

		:param model: the model module.
		:param shape: the shape of the samples, (N, C, H, W).
		:param noise: if specified, the noise from the encoder to sample.
			Should be of the same shape as `shape`.
		:param clip_denoised: if True, clip x_start predictions to [-1, 1].
		:param denoised_fn: if not None, a function which applies to the
			x_start prediction before it is used to sample.
		:param cond_fn: if not None, this is a gradient function that acts
			similarly to the model.
		:param model_kwargs: if not None, a dict of extra keyword arguments to
			pass to the model. This can be used for conditioning.
		:param device: if specified, the device to create the samples on.
			If not specified, use a model parameter's device.
		:param progress: if True, show a tqdm progress bar.
		:return: a non-differentiable batch of samples.
		"""
		final = None
		for sample in self.p_sample_loop_progressive(
			model,
			shape,
			noise=noise,
			clip_denoised=clip_denoised,
			denoised_fn=denoised_fn,
			cond_fn=cond_fn,
			model_kwargs=model_kwargs,
			device=device,
			progress=progress,
		):
			final = sample
		return final["sample"]

	def p_sample_loop_progressive(
		self,
		model,
		shape,
		noise=None,
		clip_denoised=True,
		denoised_fn=None,
		cond_fn=None,
		model_kwargs=None,
		device=None,
		progress=False,
	):
		"""
		Generate samples from the model and yield intermediate samples from
		each timestep of diffusion.

		Arguments are the same as p_sample_loop().
		Returns a generator over dicts, where each dict is the return value of
		p_sample(). Steps run from num_timesteps - 1 down to 0 under no_grad.
		"""
		if device is None:
			device = next(model.parameters()).device
		assert isinstance(shape, (tuple, list))
		if noise is not None:
			img = noise
		else:
			img = torch.randn(*shape, device=device)
		indices = list(range(self.num_timesteps))[::-1]

		for i in tqdm(indices, disable=not progress, desc="Diffusion"):
			t = torch.tensor([i] * shape[0], device=device)
			with torch.no_grad():
				out = self.p_sample(
					model,
					img,
					t,
					clip_denoised=clip_denoised,
					denoised_fn=denoised_fn,
					cond_fn=cond_fn,
					model_kwargs=model_kwargs,
				)
				yield out
				img = out["sample"]

	def ddim_sample(
		self,
		model,
		x,
		t,
		clip_denoised=True,
		denoised_fn=None,
		cond_fn=None,
		model_kwargs=None,
		eta=0.0,
	):
		"""
		Sample x_{t-1} from the model using DDIM.

		Same usage as p_sample(); ``eta`` scales the injected noise
		(0.0 gives the deterministic DDIM update).
		"""
		out = self.p_mean_variance(
			model,
			x,
			t,
			clip_denoised=clip_denoised,
			denoised_fn=denoised_fn,
			model_kwargs=model_kwargs,
		)
		if cond_fn is not None:
			out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

		# Usually our model outputs epsilon, but we re-derive it
		# in case we used x_start or x_prev prediction.
		eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

		alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
		alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
		sigma = (
			eta
			* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
			* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
		)
		# Equation 12.
		noise = torch.randn_like(x)
		mean_pred = (
			out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
			+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
		)
		nonzero_mask = (
			(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
		)  # no noise when t == 0
		sample = mean_pred + nonzero_mask * sigma * noise
		return {"sample": sample, "pred_xstart": out["pred_xstart"]}

	def ddim_reverse_sample(
		self,
		model,
		x,
		t,
		clip_denoised=True,
		denoised_fn=None,
		model_kwargs=None,
		eta=0.0,
	):
		"""
		Sample x_{t+1} from the model using DDIM reverse ODE.

		Only the deterministic path (eta == 0.0) is supported.
		"""
		assert eta == 0.0, "Reverse ODE only for deterministic path"
		out = self.p_mean_variance(
			model,
			x,
			t,
			clip_denoised=clip_denoised,
			denoised_fn=denoised_fn,
			model_kwargs=model_kwargs,
		)
		# Usually our model outputs epsilon, but we re-derive it
		# in case we used x_start or x_prev prediction.
		eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
		alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

		# Equation 12. reversed
		mean_pred = (
			out["pred_xstart"] * torch.sqrt(alpha_bar_next)
			+ torch.sqrt(1 - alpha_bar_next) * eps
		)

		return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

	def ddim_sample_loop(
		self,
		model,
		shape,
		noise=None,
		clip_denoised=True,
		denoised_fn=None,
		cond_fn=None,
		model_kwargs=None,
		device=None,
		progress=False,
		eta=0.0,
	):
		"""
		Generate samples from the model using DDIM.

		Same usage as p_sample_loop().
		"""
		final = None
		for sample in self.ddim_sample_loop_progressive(
			model,
			shape,
			noise=noise,
			clip_denoised=clip_denoised,
			denoised_fn=denoised_fn,
			cond_fn=cond_fn,
			model_kwargs=model_kwargs,
			device=device,
			progress=progress,
			eta=eta,
		):
			final = sample
		return final["sample"]

	def ddim_sample_loop_progressive(
		self,
		model,
		shape,
		noise=None,
		clip_denoised=True,
		denoised_fn=None,
		cond_fn=None,
		model_kwargs=None,
		device=None,
		progress=False,
		eta=0.0,
	):
		"""
		Use DDIM to sample from the model and yield intermediate samples from
		each timestep of DDIM.

		Same usage as p_sample_loop_progressive().
		"""
		if device is None:
			device = next(model.parameters()).device
		assert isinstance(shape, (tuple, list))
		if noise is not None:
			img = noise
		else:
			img = torch.randn(*shape, device=device)
		indices = list(range(self.num_timesteps))[::-1]

		for i in tqdm(indices, disable=not progress, desc="Diffusion"):
			t = torch.tensor([i] * shape[0], device=device)
			with torch.no_grad():
				out = self.ddim_sample(
					model,
					img,
					t,
					clip_denoised=clip_denoised,
					denoised_fn=denoised_fn,
					cond_fn=cond_fn,
					model_kwargs=model_kwargs,
					eta=eta,
				)
				yield out
				img = out["sample"]

	def _vb_terms_bpd(
		self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
	):
		"""
		Get a term for the variational lower-bound.

		The resulting units are bits (rather than nats, as one might expect).
		This allows for comparison to other papers.

		:return: a dict with the following keys:
				 - 'output': a shape [N] tensor of NLLs or KLs.
				 - 'pred_xstart': the x_0 predictions.
		"""
		true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
			x_start=x_start, x_t=x_t, t=t
		)
		out = self.p_mean_variance(
			model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
		)
		kl = normal_kl(
			true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
		)
		kl = mean_flat(kl) / np.log(2.0)

		decoder_nll = -discretized_gaussian_log_likelihood(
			x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
		)
		assert decoder_nll.shape == x_start.shape
		decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

		# At the first timestep return the decoder NLL,
		# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
		output = torch.where((t == 0), decoder_nll, kl)
		return {"output": output, "pred_xstart": out["pred_xstart"]}

	def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
		"""
		Compute training losses for a single timestep.

		:param model: the model to evaluate loss on.
		:param x_start: the [N x C x ...] tensor of inputs.
		:param t: a batch of timestep indices.
		:param model_kwargs: if not None, a dict of extra keyword arguments to
			pass to the model. This can be used for conditioning.
		:param noise: if specified, the specific Gaussian noise to try to remove.
		:return: a dict with the key "loss" containing a tensor of shape [N].
				 Some mean or variance settings may also have other keys
				 ("mse", "vb", "x_start_predicted", "extra_outputs").
		"""
		if model_kwargs is None:
			model_kwargs = {}
		if noise is None:
			noise = torch.randn_like(x_start)
		x_t = self.q_sample(x_start, t, noise=noise)

		terms = {}

		if self.loss_type.is_vb():
			# TODO: support multiple model outputs for this mode.
			terms["loss"] = self._vb_terms_bpd(
				model=model,
				x_start=x_start,
				x_t=x_t,
				t=t,
				clip_denoised=False,
				model_kwargs=model_kwargs,
			)["output"]
			if self.loss_type == LossType.RESCALED_KL:
				terms["loss"] *= self.num_timesteps
		elif self.loss_type in (LossType.MSE, LossType.RESCALED_MSE):
			model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
			# A tuple output carries auxiliary results after the main prediction.
			if isinstance(model_outputs, tuple):
				model_output = model_outputs[0]
				terms['extra_outputs'] = model_outputs[1:]
			else:
				model_output = model_outputs

			if self.model_var_type in [
				ModelVarType.LEARNED,
				ModelVarType.LEARNED_RANGE,
			]:
				B, C = x_t.shape[:2]
				assert model_output.shape == (B, C * 2, *x_t.shape[2:])
				model_output, model_var_values = torch.split(model_output, C, dim=1)
				# Learn the variance using the variational bound, but don't let
				# it affect our mean prediction.
				frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
				terms["vb"] = self._vb_terms_bpd(
					model=lambda *args, r=frozen_out: r,
					x_start=x_start,
					x_t=x_t,
					t=t,
					clip_denoised=False,
				)["output"]
				if self.loss_type == LossType.RESCALED_MSE:
					# Divide by 1000 for equivalence with initial implementation.
					# Without a factor of 1/1000, the VB term hurts the MSE term.
					terms["vb"] *= self.num_timesteps / 1000.0

			if self.model_mean_type == ModelMeanType.PREVIOUS_X:
				target = self.q_posterior_mean_variance(
					x_start=x_start, x_t=x_t, t=t
				)[0]
				# x_0 prediction is not supported here; report zeros of the right shape.
				x_start_pred = torch.zeros_like(x_start)
			elif self.model_mean_type == ModelMeanType.START_X:
				target = x_start
				x_start_pred = model_output
			elif self.model_mean_type == ModelMeanType.EPSILON:
				target = noise
				x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
			else:
				raise NotImplementedError(self.model_mean_type)
			assert model_output.shape == target.shape == x_start.shape
			terms["mse"] = mean_flat((target - model_output) ** 2)
			terms["x_start_predicted"] = x_start_pred
			if "vb" in terms:
				terms["loss"] = terms["mse"] + terms["vb"]
			else:
				terms["loss"] = terms["mse"]
		else:
			raise NotImplementedError(self.loss_type)

		return terms

	def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
		"""
		Compute training losses for a single timestep, for models that also
		receive x_start and return several named outputs.

		The model is called as ``model(x_t, x_start, scaled_t, **model_kwargs)``.
		Its outputs are zipped with ``model_output_keys`` into the returned dict,
		and the entry named ``gd_out_key`` is used as the diffusion prediction.
		With learned variances that prediction is expected to be shaped
		(B, C, 2, ...), mean at index 0 and variance values at index 1 of dim 2.

		:param model: the model to evaluate loss on.
		:param x_start: the [N x C x ...] tensor of inputs.
		:param t: a batch of timestep indices.
		:param model_output_keys: names for the model's outputs, in order.
		:param gd_out_key: which of those names holds the diffusion prediction.
		:param model_kwargs: if not None, a dict of extra keyword arguments to
			pass to the model. This can be used for conditioning.
		:param noise: if specified, the specific Gaussian noise to try to remove.
		:return: a dict with the key "loss" containing a tensor of shape [N].
				 Some mean or variance settings may also have other keys.
		:raises NotImplementedError: for KL / RESCALED_KL loss types.
		"""
		if model_kwargs is None:
			model_kwargs = {}
		if noise is None:
			noise = torch.randn_like(x_start)
		x_t = self.q_sample(x_start, t, noise=noise)
		terms = {}
		if self.loss_type.is_vb():
			# Raise rather than `assert False`, which disappears under `python -O`.
			raise NotImplementedError(
				f"{self.loss_type} is not supported by autoregressive_training_losses"
			)
		elif self.loss_type in (LossType.MSE, LossType.RESCALED_MSE):
			model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
			terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
			model_output = terms[gd_out_key]
			if self.model_var_type in [
				ModelVarType.LEARNED,
				ModelVarType.LEARNED_RANGE,
			]:
				B, C = x_t.shape[:2]
				assert model_output.shape == (B, C, 2, *x_t.shape[2:])
				model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
				# Learn the variance using the variational bound, but don't let
				# it affect our mean prediction.
				frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
				terms["vb"] = self._vb_terms_bpd(
					model=lambda *args, r=frozen_out: r,
					x_start=x_start,
					x_t=x_t,
					t=t,
					clip_denoised=False,
				)["output"]
				if self.loss_type == LossType.RESCALED_MSE:
					# Divide by 1000 for equivalence with initial implementation.
					# Without a factor of 1/1000, the VB term hurts the MSE term.
					terms["vb"] *= self.num_timesteps / 1000.0

			if self.model_mean_type == ModelMeanType.PREVIOUS_X:
				target = self.q_posterior_mean_variance(
					x_start=x_start, x_t=x_t, t=t
				)[0]
				# x_0 prediction is not supported here; report zeros of the right shape.
				x_start_pred = torch.zeros_like(x_start)
			elif self.model_mean_type == ModelMeanType.START_X:
				target = x_start
				x_start_pred = model_output
			elif self.model_mean_type == ModelMeanType.EPSILON:
				target = noise
				x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
			else:
				raise NotImplementedError(self.model_mean_type)
			assert model_output.shape == target.shape == x_start.shape
			terms["mse"] = mean_flat((target - model_output) ** 2)
			terms["x_start_predicted"] = x_start_pred
			if "vb" in terms:
				terms["loss"] = terms["mse"] + terms["vb"]
			else:
				terms["loss"] = terms["mse"]
		else:
			raise NotImplementedError(self.loss_type)

		return terms

	def _prior_bpd(self, x_start):
		"""
		Get the prior KL term for the variational lower-bound, measured in
		bits-per-dim.

		This term can't be optimized, as it only depends on the encoder.

		:param x_start: the [N x C x ...] tensor of inputs.
		:return: a batch of [N] KL values (in bits), one per batch element.
		"""
		batch_size = x_start.shape[0]
		t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
		qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
		kl_prior = normal_kl(
			mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
		)
		return mean_flat(kl_prior) / np.log(2.0)

	def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
		"""
		Compute the entire variational lower-bound, measured in bits-per-dim,
		as well as other related quantities.

		:param model: the model to evaluate loss on.
		:param x_start: the [N x C x ...] tensor of inputs.
		:param clip_denoised: if True, clip denoised samples.
		:param model_kwargs: if not None, a dict of extra keyword arguments to
			pass to the model. This can be used for conditioning.

		:return: a dict containing the following keys:
				 - total_bpd: the total variational lower-bound, per batch element.
				 - prior_bpd: the prior term in the lower-bound.
				 - vb: an [N x T] tensor of terms in the lower-bound.
				 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
				 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
				 Columns are ordered from the last timestep down to 0.
		"""
		device = x_start.device
		batch_size = x_start.shape[0]

		vb = []
		xstart_mse = []
		mse = []
		for t in list(range(self.num_timesteps))[::-1]:
			t_batch = torch.tensor([t] * batch_size, device=device)
			noise = torch.randn_like(x_start)
			x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
			# Calculate VLB term at the current timestep
			with torch.no_grad():
				out = self._vb_terms_bpd(
					model,
					x_start=x_start,
					x_t=x_t,
					t=t_batch,
					clip_denoised=clip_denoised,
					model_kwargs=model_kwargs,
				)
			vb.append(out["output"])
			xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
			eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
			mse.append(mean_flat((eps - noise) ** 2))

		vb = torch.stack(vb, dim=1)
		xstart_mse = torch.stack(xstart_mse, dim=1)
		mse = torch.stack(mse, dim=1)

		prior_bpd = self._prior_bpd(x_start)
		total_bpd = vb.sum(dim=1) + prior_bpd
		return {
			"total_bpd": total_bpd,
			"prior_bpd": prior_bpd,
			"vb": vb,
			"xstart_mse": xstart_mse,
			"mse": mse,
		}


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
	"""
	Return the beta schedule registered under ``schedule_name``.

	Schedules here are defined so they remain comparable as
	``num_diffusion_timesteps`` changes. New schedules may be added, but existing
	ones must never be changed or removed, to keep backwards compatibility.

	:param schedule_name: "linear" or "cosine".
	:param num_diffusion_timesteps: number of betas to produce.
	:raises NotImplementedError: if ``schedule_name`` is not a known schedule.
	"""
	if schedule_name == "cosine":
		return betas_for_alpha_bar(
			num_diffusion_timesteps,
			lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
		)
	if schedule_name != "linear":
		raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
	# Ho et al.'s linear schedule, rescaled so that it matches the 1000-step
	# original when num_diffusion_timesteps == 1000.
	scale = 1000 / num_diffusion_timesteps
	first_beta, last_beta = scale * 0.0001, scale * 0.02
	return np.linspace(first_beta, last_beta, num_diffusion_timesteps, dtype=np.float64)


class SpacedDiffusion(GaussianDiffusion):
	"""
	A diffusion process that keeps only a subset of a base process's timesteps.

	Betas are recomputed so that the cumulative alpha product at every kept
	timestep equals the base process's value there. Models are wrapped so they
	still receive timesteps on the original (unspaced) scale.

	:param use_timesteps: a collection (sequence or set) of timesteps from the
						  original diffusion process to retain.
	:param kwargs: the kwargs to create the base diffusion process.
	"""

	def __init__(self, use_timesteps, **kwargs):
		self.use_timesteps = set(use_timesteps)
		self.timestep_map = []
		self.original_num_steps = len(kwargs["betas"])

		base = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
		previous_alpha_bar = 1.0
		respaced_betas = []
		for step, alpha_bar in enumerate(base.alphas_cumprod):
			if step not in self.use_timesteps:
				continue
			# Pick beta so that the running product of (1 - beta) reproduces alpha_bar.
			respaced_betas.append(1 - alpha_bar / previous_alpha_bar)
			previous_alpha_bar = alpha_bar
			self.timestep_map.append(step)
		kwargs["betas"] = np.array(respaced_betas)
		super().__init__(**kwargs)

	def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
		return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

	def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
		return super().training_losses(self._wrap_model(model), *args, **kwargs)

	def autoregressive_training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
		wrapped = self._wrap_model(model, autoregressive=True)
		return super().autoregressive_training_losses(wrapped, *args, **kwargs)

	def condition_mean(self, cond_fn, *args, **kwargs):
		return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

	def condition_score(self, cond_fn, *args, **kwargs):
		return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

	def _wrap_model(self, model, autoregressive=False):
		# Already-wrapped models are returned untouched to avoid double remapping.
		if isinstance(model, (_WrappedModel, _WrappedAutoregressiveModel)):
			return model
		wrapper_cls = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
		return wrapper_cls(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)

	def _scale_timesteps(self, t):
		# Scaling is done by the wrapped model.
		return t


def space_timesteps(num_timesteps, section_counts):
	"""
	Create a list of timesteps to use from an original diffusion process,
	given the number of timesteps we want to take from equally-sized portions
	of the original process.

	For example, if there's 300 timesteps and the section counts are [10,15,20]
	then the first 100 timesteps are strided to be 10 timesteps, the second 100
	are strided to be 15 timesteps, and the final 100 are strided to be 20.

	If the stride is a string starting with "ddim", then the fixed striding
	from the DDIM paper is used, and only one section is allowed.

	:param num_timesteps: the number of diffusion steps in the original
						  process to divide up.
	:param section_counts: either a list of numbers, or a string containing
						   comma-separated numbers, indicating the step count
						   per section. As a special case, use "ddimN" where N
						   is a number of steps to use the striding from the
						   DDIM paper.
	:return: a set of diffusion steps from the original process to use.
	:raises ValueError: if no integer stride yields exactly N steps for "ddimN",
						if ``section_counts`` is empty, or if a section has
						fewer steps than requested from it.
	"""
	if isinstance(section_counts, str):
		if section_counts.startswith("ddim"):
			desired_count = int(section_counts[len("ddim"):])
			# Strides run up to num_timesteps inclusive: the largest stride keeps only
			# step 0, so a single-step request ("ddim1") is satisfiable. Smaller strides
			# are tried first, so results for other counts are unchanged.
			for stride in range(1, num_timesteps + 1):
				if len(range(0, num_timesteps, stride)) == desired_count:
					return set(range(0, num_timesteps, stride))
			raise ValueError(
				f"cannot create exactly {desired_count} steps with an integer stride"
			)
		section_counts = [int(x) for x in section_counts.split(",")]
	if not section_counts:
		raise ValueError("section_counts must contain at least one section")
	size_per = num_timesteps // len(section_counts)
	# The first `extra` sections absorb the remainder, one extra step each.
	extra = num_timesteps % len(section_counts)
	start_idx = 0
	all_steps = []
	for i, section_count in enumerate(section_counts):
		size = size_per + (1 if i < extra else 0)
		if size < section_count:
			raise ValueError(
				f"cannot divide section of {size} steps into {section_count}"
			)
		if section_count <= 1:
			frac_stride = 1
		else:
			# Spread the picks so both the first and the last step of the section are taken.
			frac_stride = (size - 1) / (section_count - 1)
		cur_idx = 0.0
		taken_steps = []
		for _ in range(section_count):
			taken_steps.append(start_idx + round(cur_idx))
			cur_idx += frac_stride
		all_steps += taken_steps
		start_idx += size
	return set(all_steps)


class _WrappedModel:
	def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
		self.model = model
		self.timestep_map = timestep_map
		self.rescale_timesteps = rescale_timesteps
		self.original_num_steps = original_num_steps

	def __call__(self, x, ts, **kwargs):
		map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
		new_ts = map_tensor[ts]
		if self.rescale_timesteps:
			new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
		return self.model(x, new_ts, **kwargs)


class _WrappedAutoregressiveModel:
	def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
		self.model = model
		self.timestep_map = timestep_map
		self.rescale_timesteps = rescale_timesteps
		self.original_num_steps = original_num_steps

	def __call__(self, x, x0, ts, **kwargs):
		map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
		new_ts = map_tensor[ts]
		if self.rescale_timesteps:
			new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
		return self.model(x, x0, new_ts, **kwargs)

def _extract_into_tensor(arr, timesteps, broadcast_shape):
	"""
	Extract values from a 1-D numpy array for a batch of indices.

	:param arr: the 1-D numpy array.
	:param timesteps: a tensor of indices into the array to extract.
	:param broadcast_shape: a larger shape of K dimensions with the batch
							dimension equal to the length of timesteps.
	:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
	"""
	res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
	while len(res.shape) < len(broadcast_shape):
		res = res[..., None]
	return res.expand(broadcast_shape)

def is_latent(t):
	"""
	Return True if ``t`` is continuous (latent) conditioning rather than token ids.

	Any floating-point dtype counts. Reduced-precision latents (float16/bfloat16,
	e.g. produced under autocast) are therefore routed to the latent path rather
	than to the token-embedding path, where ``nn.Embedding`` would reject them.
	"""
	return torch.is_floating_point(t)


def is_sequence(t):
	"""Return True when ``t`` has dtype ``torch.long`` (int64), i.e. token ids."""
	return torch.long == t.dtype


def timestep_embedding(timesteps, dim, max_period=10000):
	"""
	Build sinusoidal embeddings for a batch of timesteps.

	:param timesteps: a 1-D Tensor of N indices, one per batch element.
					  These may be fractional.
	:param dim: width of each embedding; cosines fill the first half and sines
				the second, with one zero column appended when ``dim`` is odd.
	:param max_period: controls the minimum frequency of the embeddings.
	:return: an [N x dim] Tensor of positional embeddings.
	"""
	half_dim = dim // 2
	log_scaled = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
	freqs = torch.exp(log_scaled / half_dim).to(device=timesteps.device)
	phases = timesteps[:, None].float() * freqs[None]
	pieces = [torch.cos(phases), torch.sin(phases)]
	if dim % 2:
		# Pad with a single zero column so the output width is exactly `dim`.
		pieces.append(torch.zeros_like(phases[:, :1]))
	return torch.cat(pieces, dim=-1)


class TimestepBlock(nn.Module):
	"""
	Base class for modules whose ``forward`` also receives timestep embeddings.

	TimestepEmbedSequential uses ``isinstance(layer, TimestepBlock)`` to decide
	whether to pass the embedding to a child.
	"""

	@abstractmethod
	def forward(self, x, emb):
		"""
		Transform ``x`` using the timestep embeddings ``emb``.
		"""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
	"""
	Sequential container that hands the timestep embedding to every child that
	is a TimestepBlock and calls all other children with ``x`` alone.
	"""

	def forward(self, x, emb):
		for child in self:
			x = child(x, emb) if isinstance(child, TimestepBlock) else child(x)
		return x


class ResBlock(TimestepBlock):
	"""
	Residual block of 1-D convolutions whose hidden activations are modulated by
	a timestep embedding.

	:param channels: number of input channels.
	:param emb_channels: width of the embedding passed to ``forward``.
	:param dropout: dropout probability applied before the output convolution.
	:param out_channels: number of output channels; defaults to ``channels``.
	:param dims: unused (all convolutions are 1-D); kept for call-site compatibility.
	:param kernel_size: kernel size of the output convolution; must be a positive
		odd integer so "same" padding preserves the sequence length.
	:param efficient_config: if True, the input convolution and the skip projection
		use kernel size 1; otherwise kernel size 3 with padding 1.
	:param use_scale_shift_norm: if True, the embedding is projected to a
		per-channel (scale, shift) applied after the first output normalization;
		otherwise it is projected to ``out_channels`` and added.
	:raises ValueError: if ``kernel_size`` is not a positive odd integer.
	"""

	def __init__(
		self,
		channels,
		emb_channels,
		dropout,
		out_channels=None,
		dims=2,
		kernel_size=3,
		efficient_config=True,
		use_scale_shift_norm=False,
	):
		super().__init__()
		if kernel_size < 1 or kernel_size % 2 == 0:
			raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
		self.channels = channels
		self.emb_channels = emb_channels
		self.dropout = dropout
		self.out_channels = out_channels or channels
		self.use_scale_shift_norm = use_scale_shift_norm
		# "Same" padding for an odd kernel keeps the length unchanged, which the
		# residual addition in forward() relies on (1->0, 3->1, 5->2, ...).
		padding = kernel_size // 2
		eff_kernel = 1 if efficient_config else 3
		eff_padding = 0 if efficient_config else 1

		self.in_layers = nn.Sequential(
			normalization(channels),
			nn.SiLU(),
			nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
		)

		self.emb_layers = nn.Sequential(
			nn.SiLU(),
			nn.Linear(
				emb_channels,
				# Twice the width when the embedding supplies both a scale and a shift.
				2 * self.out_channels if use_scale_shift_norm else self.out_channels,
			),
		)
		self.out_layers = nn.Sequential(
			normalization(self.out_channels),
			nn.SiLU(),
			nn.Dropout(p=dropout),
			nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
		)

		if self.out_channels == channels:
			self.skip_connection = nn.Identity()
		else:
			# Project the input so it can be added to the wider/narrower output.
			self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)

	def forward(self, x, emb):
		"""
		:param x: input activations with ``channels`` channels in dim 1.
		:param emb: embedding of width ``emb_channels`` per batch element.
		:return: ``skip_connection(x) + h`` with ``out_channels`` channels.
		"""
		h = self.in_layers(x)
		emb_out = self.emb_layers(emb).type(h.dtype)
		# Append trailing singleton dims so the embedding broadcasts over the sequence.
		while len(emb_out.shape) < len(h.shape):
			emb_out = emb_out[..., None]
		if self.use_scale_shift_norm:
			out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
			scale, shift = torch.chunk(emb_out, 2, dim=1)
			h = out_norm(h) * (1 + scale) + shift
			h = out_rest(h)
		else:
			h = h + emb_out
			h = self.out_layers(h)
		return self.skip_connection(x) + h


class DiffusionLayer(TimestepBlock):
	"""
	One stage of the backbone: a timestep-conditioned ResBlock (scale/shift
	norm) followed by an AttentionBlock with relative position embeddings.
	"""

	def __init__(self, model_channels, dropout, num_heads):
		super().__init__()
		self.resblk = ResBlock(
			model_channels,
			model_channels,
			dropout,
			model_channels,
			dims=1,
			use_scale_shift_norm=True,
		)
		self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)

	def forward(self, x, time_emb):
		return self.attn(self.resblk(x, time_emb))

class DiffusionTTS(nn.Module):
	"""
	Timestep-conditioned diffusion denoiser over [N x in_channels x S] frames
	(mel spectrograms, judging by ``mel_head``).

	Conditioning comes from two sources. The first is an *aligned* signal, either
	token ids or a float latent sequence, which is stretched to the input length.
	The second is a global conditioning latent from ``get_conditioning``, which
	scales and shifts the aligned embedding.

	:param model_channels: width of the main residual/attention stack.
	:param num_layers: number of DiffusionLayer blocks in the main stack; three
		ResBlocks are always appended after them.
	:param in_channels: channels of the noisy input ``x`` and of the clips passed
		to ``get_conditioning``.
	:param in_latent_channels: channels of latent aligned conditioning.
	:param in_tokens: vocabulary size of token aligned conditioning.
	:param out_channels: output channels (200 by default; described as mean and
		variance).
	:param dropout: dropout probability inside the ResBlocks.
	:param use_fp16: run main-stack layers other than the first under autocast.
	:param num_heads: number of heads in every AttentionBlock.
	:param layer_drop: training-only probability of skipping each main-stack
		layer other than the first and the last.
	:param unconditioned_percentage: training-only fraction of batch elements
		whose aligned conditioning is replaced with a learned unconditioned
		embedding, similar to classifier-free guidance training.
	"""

	def __init__(
		self,
		model_channels=1024,  # 512
		num_layers=10,  # 8
		in_channels=100,
		in_latent_channels=1024,  # 512
		in_tokens=8193,
		out_channels=200,  # mean and variance
		dropout=0,
		use_fp16=False,
		num_heads=16,
		# Parameters for regularization.
		layer_drop=0,  # 0.1
		unconditioned_percentage=0,  # 0.1  # This implements a mechanism similar to what is used in classifier-free training.
	):
		super().__init__()

		self.in_channels = in_channels
		self.model_channels = model_channels
		self.out_channels = out_channels
		self.dropout = dropout
		self.num_heads = num_heads
		self.unconditioned_percentage = unconditioned_percentage
		self.enable_fp16 = use_fp16
		self.layer_drop = layer_drop

		self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
		self.time_embed = nn.Sequential(
			nn.Linear(model_channels, model_channels),
			nn.SiLU(),
			nn.Linear(model_channels, model_channels),
		)

		# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
		# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
		# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
		# transformer network.
		self.code_embedding = nn.Embedding(in_tokens, model_channels)
		self.code_converter = nn.Sequential(
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
		)
		self.code_norm = normalization(model_channels)
		self.latent_conditioner = nn.Sequential(
			nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
			AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
		)
		# Two stride-2 convs downsample the conditioning clip by 4x and widen it to
		# 2*model_channels (later split into a scale and a shift).
		self.contextual_embedder = nn.Sequential(
			nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
			nn.Conv1d(model_channels, model_channels*2, 3, padding=1, stride=2),
			AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
			AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
			AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
			AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
			AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
		)
		self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
		self.conditioning_timestep_integrator = TimestepEmbedSequential(
			DiffusionLayer(model_channels, dropout, num_heads),
			DiffusionLayer(model_channels, dropout, num_heads),
			DiffusionLayer(model_channels, dropout, num_heads),
		)

		self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
		self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)

		self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
									[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])

		self.out = nn.Sequential(
			normalization(model_channels),
			nn.SiLU(),
			nn.Conv1d(model_channels, out_channels, 3, padding=1),
		)

	def get_grad_norm_parameter_groups(self):
		"""
		Return a dict of parameter lists grouped by sub-network (presumably used by
		a trainer to report per-group gradient norms — confirm with the caller).

		Each module's parameters appear once per group so they are not
		double-counted in that group's norm.
		"""
		groups = {
			'minicoder': list(self.contextual_embedder.parameters()),
			'layers': list(self.layers.parameters()),
			'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()),
			'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
			'time_embed': list(self.time_embed.parameters()),
		}
		return groups

	def get_conditioning(self, conditioning_input):
		"""
		Encode one or more conditioning clips into a single conditioning latent.

		:param conditioning_input: [N x in_channels x S] or
			[N x K x in_channels x S] (K clips per batch element).
		:return: [N x 2*model_channels], the mean over all clips and positions of
			the contextual embedder's output.
		"""
		speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
			conditioning_input.shape) == 3 else conditioning_input
		conds = []
		for j in range(speech_conditioning_input.shape[1]):
			conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
		conds = torch.cat(conds, dim=-1)
		conds = conds.mean(dim=-1)
		return conds

	def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
		"""
		Compute the timestep-independent aligned embedding, stretched to
		``expected_seq_len`` with nearest-neighbour interpolation.

		:param aligned_conditioning: float latents [N x S' x in_latent_channels]
			or integer token ids [N x S'].
		:param conditioning_latent: output of ``get_conditioning``; split along
			dim 1 into a scale and a shift for the normalized embedding.
		:param expected_seq_len: target sequence length.
		:param return_code_pred: if True, also return a prediction from
			``mel_head`` (zeroed for unconditioned batch elements).
		:return: the embedding, or ``(embedding, mel_pred)``.
		"""
		# Shuffle aligned_latent to BxCxS format
		if is_latent(aligned_conditioning):
			aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

		cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
		if is_latent(aligned_conditioning):
			code_emb = self.latent_conditioner(aligned_conditioning)
		else:
			code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
			code_emb = self.code_converter(code_emb)
		code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)

		unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
		# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
		if self.training and self.unconditioned_percentage > 0:
			unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
											   device=code_emb.device) < self.unconditioned_percentage
			code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
								   code_emb)
		expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

		if not return_code_pred:
			return expanded_code_emb
		else:
			mel_pred = self.mel_head(expanded_code_emb)
			# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
			mel_pred = mel_pred * unconditioned_batches.logical_not()
			return expanded_code_emb, mel_pred

	def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
		"""
		Apply the model to an input batch.

		:param x: an [N x C x ...] Tensor of inputs.
		:param timesteps: a 1-D batch of timesteps.
		:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
		:param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
		:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
		:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
		:param return_code_pred: if set, also return the code prediction from
			timestep_independent(). Mutually exclusive with precomputed_aligned_embeddings.
		:return: an [N x C x ...] Tensor of outputs, or ``(outputs, mel_pred)``
			when return_code_pred is set. ``mel_pred`` is None when
			conditioning_free is set, because no aligned conditioning is encoded.
		"""
		assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_latent is not None)
		assert not (return_code_pred and precomputed_aligned_embeddings is not None)  # These two are mutually exclusive.

		# Only assigned on the path that encodes aligned_conditioning below.
		mel_pred = None
		unused_params = []
		if conditioning_free:
			code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
			unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
			unused_params.extend(list(self.latent_conditioner.parameters()))
		else:
			if precomputed_aligned_embeddings is not None:
				code_emb = precomputed_aligned_embeddings
			else:
				code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1], True)
				if is_latent(aligned_conditioning):
					unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
				else:
					unused_params.extend(list(self.latent_conditioner.parameters()))

			unused_params.append(self.unconditioned_embedding)

		time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
		code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
		x = self.inp_block(x)
		x = torch.cat([x, code_emb], dim=1)
		x = self.integrating_conv(x)
		for i, lyr in enumerate(self.layers):
			# Layer drop (training only): never drop the first or the last layer.
			if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
				unused_params.extend(list(lyr.parameters()))
			else:
				# The first block always runs with autocast disabled for improved precision.
				# NOTE(review): only the first block is excluded here, not the last.
				with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
					x = lyr(x, time_emb)

		x = x.float()
		out = self.out(x)

		# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
		extraneous_addition = 0
		for p in unused_params:
			extraneous_addition = extraneous_addition + p.mean()
		out = out + extraneous_addition * 0

		if return_code_pred:
			return out, mel_pred
		return out

def get_diffuser(
	steps=80,
	cond_free=True,
	cond_free_k=2,
	trained_diffusion_steps=4000,
):
	"""
	Build a SpacedDiffusion that samples with ``steps`` timesteps drawn from a
	linear schedule of ``trained_diffusion_steps`` steps.

	The model predicts epsilon and a learned-range variance, and MSE is the loss.

	:param steps: number of timesteps kept (a single section of space_timesteps).
	:param cond_free: forwarded as ``conditioning_free``.
	:param cond_free_k: forwarded as ``conditioning_free_k``.
	:param trained_diffusion_steps: length of the original beta schedule.
	"""
	kept_timesteps = space_timesteps(trained_diffusion_steps, [steps])
	diffusion_kwargs = dict(
		model_mean_type='epsilon',
		model_var_type='learned_range',
		loss_type='mse',
		betas=get_named_beta_schedule('linear', trained_diffusion_steps),
		conditioning_free=cond_free,
		conditioning_free_k=cond_free_k,
	)
	return SpacedDiffusion(use_timesteps=kept_timesteps, **diffusion_kwargs)

if __name__ == '__main__':
	# Smoke test: run both aligned-conditioning paths of a small model on random data.
	clip = torch.randn(2, 100, 400)
	aligned_latent = torch.randn(2, 388, 512)
	aligned_sequence = torch.randint(0, 8192, (2, 100))
	cond = torch.randn(2, 100, 400)
	ts = torch.LongTensor([600, 600])
	# in_latent_channels must match the last dim of aligned_latent.
	model = DiffusionTTS(512, in_latent_channels=512, layer_drop=.3, unconditioned_percentage=.5)
	# forward() expects the encoded conditioning latent, not the raw conditioning clip.
	cond_latent = model.get_conditioning(cond)
	# Test with latent aligned conditioning
	o = model(clip, ts, aligned_latent, cond_latent)
	print(o.shape)
	# Test with sequence aligned conditioning
	o = model(clip, ts, aligned_sequence, cond_latent)
	print(o.shape)