# https://github.com/microsoft/torchscale

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
# from retnet import RetNet

# Override MultiScaleRetention.forward, because training with te (presumably NVIDIA Transformer Engine — confirm) throws an error with the stock implementation.
from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift

def MultiScaleRetention_forward(
    self,
    x,
    rel_pos,
    chunkwise_recurrent=False,
    incremental_state=None,
):
    """Replacement ``forward`` for torchscale's ``MultiScaleRetention``.

    This function is bound onto the class by the module-level assignment that
    follows it. The key projection is scaled with an out-of-place multiply
    (``self.k_proj(x) * self.scaling``).
    NOTE(review): presumably the upstream implementation scales ``k`` in
    place, and that is what breaks training with "te". Confirm against the
    installed torchscale version.

    Args:
        x: 3-D input tensor. Dims 0 and 1 are read as the batch size and the
            target length.
        rel_pos: Nested pair ``((sin, cos), inner_mask)``. ``sin``/``cos`` are
            fed to ``theta_shift`` for the queries and the keys.
            ``inner_mask`` is forwarded to whichever retention kernel runs.
        chunkwise_recurrent: When truthy and ``incremental_state`` is None,
            ``self.chunk_recurrent_forward`` is used. Otherwise
            ``self.parallel_forward`` is used.
        incremental_state: When not None, ``self.recurrent_forward`` is used
            and takes precedence over ``chunkwise_recurrent``. It is passed
            straight through to that method.

    Returns:
        ``self.out_proj`` applied to ``self.gate_fn(self.g_proj(x))`` times the
        group-normalised retention output. That output is reshaped to
        ``(bsz, tgt_len, self.head_dim * self.num_heads)``.
    """
    bsz, tgt_len, _ = x.size()
    (sin, cos), inner_mask = rel_pos

    # Projections are computed in the same order as before: q, k, v, g.
    query = self.q_proj(x)
    # Out-of-place scaling on purpose; see the docstring note.
    key = self.k_proj(x) * self.scaling
    value = self.v_proj(x)
    gate = self.g_proj(x)

    # View as (bsz, tgt_len, num_heads, key_dim), then swap dims 1 and 2
    # -> (bsz, num_heads, tgt_len, key_dim).
    head_shape = (bsz, tgt_len, self.num_heads, self.key_dim)
    rotated_q = theta_shift(query.view(*head_shape).transpose(1, 2), sin, cos)
    rotated_k = theta_shift(key.view(*head_shape).transpose(1, 2), sin, cos)

    if incremental_state is not None:
        retained = self.recurrent_forward(
            rotated_q, rotated_k, value, inner_mask, incremental_state
        )
    else:
        kernel = (
            self.chunk_recurrent_forward
            if chunkwise_recurrent
            else self.parallel_forward
        )
        retained = kernel(rotated_q, rotated_k, value, inner_mask)

    merged = self.group_norm(retained).reshape(
        bsz, tgt_len, self.head_dim * self.num_heads
    )
    return self.out_proj(self.gate_fn(gate) * merged)

# Monkey-patch: every MultiScaleRetention instance (including ones created
# inside RetNetDecoder) now runs the override defined above.
setattr(MultiScaleRetention, "forward", MultiScaleRetention_forward)