from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, broadcat

# helpers

def exists(val):
	"""Return True when `val` is anything other than None."""
	if val is None:
		return False
	return True


def default(val, d):
	"""Return `val` when it is set, otherwise the fallback `d`.

	"Set" is decided by `exists`, so falsy values such as 0 or an empty
	container are kept rather than replaced.
	"""
	if exists(val):
		return val
	return d


def cast_tuple(val, depth = 1):
	"""Coerce `val` into a tuple.

	A tuple is returned unchanged, a list is converted to a tuple, and any
	other value is repeated `depth` times.
	"""
	if isinstance(val, tuple):
		return val
	if isinstance(val, list):
		return tuple(val)
	return (val,) * depth


def max_neg_value(t):
	"""Return the most negative finite value representable in `t`'s floating dtype."""
	limits = torch.finfo(t.dtype)
	return -limits.max


def stable_softmax(t, dim = -1, alpha = 32 ** 2):
	"""Softmax over `dim` that subtracts the (detached) maximum in a down-scaled space.

	The input is divided by `alpha`, shifted by its maximum along `dim`,
	multiplied back by `alpha`, and then passed to softmax.
	"""
	scaled = t / alpha
	# detach so the subtracted maximum is treated as a constant by autograd
	peak = torch.amax(scaled, dim = dim, keepdim = True).detach()
	shifted = scaled - peak
	return torch.softmax(shifted * alpha, dim = dim)


def route_args(router, args, depth):
	"""Distribute keyword arguments to the two functions of each layer.

	`router` maps an argument name to a per-layer sequence of two-element
	route flags; the first flag decides whether the argument is passed to
	the layer's first function, the second flag to its second function.
	Arguments whose name is not in `router` are ignored.

	Returns a list of `depth` pairs `(f_args, g_args)` of keyword dicts.
	"""
	routed_args = [({}, {}) for _ in range(depth)]

	for key, val in args.items():
		if key not in router:
			continue

		per_layer = zip(routed_args, router[key])
		for layer_ind, ((f_args, g_args), routes) in enumerate(per_layer):
			to_f, to_g = routes
			extra_f = {key: val} if to_f else {}
			extra_g = {key: val} if to_g else {}
			routed_args[layer_ind] = ({**f_args, **extra_f}, {**g_args, **extra_g})

	return routed_args


# classes
class SequentialSequence(nn.Module):
	"""Applies a sequence of `(f, g)` layer pairs, each with a residual connection.

	Args:
		layers: sequence of two-element layer pairs `(f, g)`.
		args_route: optional mapping from keyword-argument name to a
			per-layer sequence of `(route_to_f, route_to_g)` flags. Every
			route must have one entry per layer. Defaults to no routing.
		layer_dropout: stored on the instance; not used by `forward` here.
	"""

	def __init__(self, layers, args_route = None, layer_dropout = 0.):
		super().__init__()
		# A `{}` default would be one dict shared by every instance; build a
		# fresh one per instance instead.
		args_route = {} if args_route is None else args_route
		assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
		self.layers = layers
		self.args_route = args_route
		self.layer_dropout = layer_dropout

	def forward(self, x, **kwargs):
		"""Run `x` through every layer pair, routing `kwargs` per `args_route`."""
		args = route_args(self.args_route, kwargs, len(self.layers))

		for (f, g), (f_args, g_args) in zip(self.layers, args):
			x = x + f(x, **f_args)
			x = x + g(x, **g_args)
		return x


class DivideMax(nn.Module):
	"""Divides the input by its maximum along a fixed dimension."""

	def __init__(self, dim):
		super().__init__()
		self.dim = dim

	def forward(self, x):
		peak = torch.amax(x, dim = self.dim, keepdim = True)
		# the divisor is detached, so no gradient flows through the maximum
		return x / peak.detach()


# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
	"""Multiplies the output of `fn` by a learned per-feature scale.

	The scale is a parameter of shape `(1, 1, dim)` whose initial value
	depends on `depth`: 0.1 up to 18, 1e-5 up to 24, and 1e-6 beyond that.
	"""

	def __init__(self, dim, depth, fn):
		super().__init__()
		if depth <= 18:
			init_eps = 0.1
		elif depth <= 24:
			init_eps = 1e-5
		else:
			init_eps = 1e-6

		self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
		self.fn = fn

	def forward(self, x, **kwargs):
		out = self.fn(x, **kwargs)
		return out * self.scale

# layer norm


class PreNorm(nn.Module):
	"""Applies LayerNorm before `fn`, and optionally a second LayerNorm after it.

	With `sandwich = True` the output of `fn` is normalized as well;
	otherwise the output stage is an identity.
	"""

	def __init__(self, dim, fn, sandwich = False):
		super().__init__()
		self.norm = nn.LayerNorm(dim)
		if sandwich:
			self.norm_out = nn.LayerNorm(dim)
		else:
			self.norm_out = nn.Identity()
		self.fn = fn

	def forward(self, x, **kwargs):
		normed = self.norm(x)
		return self.norm_out(self.fn(normed, **kwargs))

# feed forward


class GEGLU(nn.Module):
	"""Gated GELU: halves the last dimension and gates one half with GELU of the other."""

	def forward(self, x):
		value, gate = x.chunk(2, dim = -1)
		return value * F.gelu(gate)


class FeedForward(nn.Module):
	"""Position-wise feed-forward block: Linear -> GEGLU -> Dropout -> Linear.

	Args:
		dim: input and output feature size.
		dropout: dropout probability applied after the gated activation.
		mult: expansion factor for the hidden width; may be an int or a
			float. The hidden width is `int(dim * mult)`.
	"""

	def __init__(self, dim, dropout = 0., mult = 4.):
		super().__init__()
		# nn.Linear requires integer sizes. With the float default for
		# `mult`, `dim * mult` is a float, so convert once and reuse it for
		# both projections. The first projection is twice as wide because
		# GEGLU splits its input in half.
		inner_dim = int(dim * mult)
		self.net = nn.Sequential(
			nn.Linear(dim, inner_dim * 2),
			GEGLU(),
			nn.Dropout(dropout),
			nn.Linear(inner_dim, dim)
		)

	def forward(self, x):
		return self.net(x)

# Attention


class Attention(nn.Module):
	"""Multi-head scaled dot-product self-attention with optional causal masking.

	Args:
		dim: feature size of the input and output.
		seq_len: stored on the instance; not read by `forward` here.
		causal: when True, each query position is prevented from attending
			to later key positions.
		heads: number of attention heads.
		dim_head: feature size of each head.
		dropout: dropout probability applied after the output projection.
	"""

	def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
		super().__init__()
		self.heads = heads
		self.seq_len = seq_len
		self.causal = causal
		self.scale = dim_head ** -0.5

		inner_dim = heads * dim_head
		self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
		self.to_out = nn.Sequential(
			nn.Linear(inner_dim, dim),
			nn.Dropout(dropout)
		)

	def forward(self, x, mask = None):
		"""Attend over `x`.

		`x` must be 3-dimensional (its shape is unpacked into three values).
		`mask`, when given, is a 2-dimensional tensor that is inverted with
		`~`; key positions where it is False are filled with a large
		negative value before the softmax.
		"""
		_batch, _seq, _feat = x.shape
		heads = self.heads

		# project once, then split into queries, keys and values per head
		q, k, v = (
			rearrange(t, 'b n (h d) -> b h n d', h = heads)
			for t in self.to_qkv(x).chunk(3, dim = -1)
		)
		q = q * self.scale

		dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
		fill_value = max_neg_value(dots)

		if exists(mask):
			# broadcast the per-key mask over heads and query positions
			key_mask = rearrange(mask, 'b j -> b () () j')
			dots.masked_fill_(~key_mask, fill_value)

		if self.causal:
			rows, cols = dots.shape[-2:]
			# True strictly above the diagonal offset by (cols - rows)
			future = torch.ones(rows, cols, device = x.device).triu_(cols - rows + 1).bool()
			dots.masked_fill_(future, fill_value)

		attn = torch.softmax(dots, dim = -1)

		out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
		merged = rearrange(out, 'b h n d -> b n (h d)')
		return self.to_out(merged)


# main transformer class
class Transformer(nn.Module):
	"""Stack of attention / feed-forward layer pairs executed sequentially.

	Each layer pair is an `Attention` and a `FeedForward`, each wrapped in
	`PreNorm` and then `LayerScale`; `LayerScale` receives the 1-based layer
	index as its depth. A `mask` keyword passed to `forward` is routed to
	the attention half of every layer only.
	"""

	def __init__(
		self,
		*,
		dim,
		depth,
		seq_len,
		causal = True,
		heads = 8,
		dim_head = 64,
		ff_mult = 4,
		attn_dropout = 0.,
		ff_dropout = 0.,
		sparse_attn = False,
		sandwich_norm = False,
	):
		super().__init__()
		sparse_flags = cast_tuple(sparse_attn, depth)
		blocks = nn.ModuleList([])

		# The per-layer sparse flag is not consumed below; the zip only bounds
		# how many layers are built.
		for layer_ind, _sparse in zip(range(depth), sparse_flags):
			layer_depth = layer_ind + 1

			attention = Attention(
				dim,
				causal = causal,
				seq_len = seq_len,
				heads = heads,
				dim_head = dim_head,
				dropout = attn_dropout,
			)
			feed_forward = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)

			scaled_attention = LayerScale(dim, layer_depth, PreNorm(dim, attention, sandwich = sandwich_norm))
			scaled_feed_forward = LayerScale(dim, layer_depth, PreNorm(dim, feed_forward, sandwich = sandwich_norm))
			blocks.append(nn.ModuleList([scaled_attention, scaled_feed_forward]))

		# route `mask` to the attention function only, for every layer
		mask_routes = ((True, False),) * depth
		self.layers = SequentialSequence(blocks, args_route = {'mask': mask_routes})

	def forward(self, x, **kwargs):
		return self.layers(x, **kwargs)