import torch

from torch import Tensor
from typing import Any, Protocol

Stats = dict[str, float]

class TrainFeeder(Protocol):
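	# maps (engine, batch) to either None (skip the batch) or a (loss, stats) tuple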
	def __call__(
		self, *, engine: "Engine", batch: Any
	) -> None | tuple[Tensor, Stats]:
		...

def default_feeder(engine, batch):
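	# dispatch the batch into the module by its type, then sum every gathered "loss" attribute into a single scalar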
	if isinstance(batch, list):
		engine( *batch )
	elif isinstance(batch, dict):
		engine( **batch )
	else:
		engine( batch )

	losses = engine.gather_attribute("loss")
	loss = torch.stack([*losses.values()]).sum()

	stats = {}
	stats |= {k: v.item() for k, v in losses.items()}

	return loss, stats
	

from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict

import logging
import time
import torch
import torch.distributed
import os

from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property

from .base import TrainFeeder
from ..utils import wrapper as ml

_logger = logging.getLogger(__name__)

if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
	init_distributed(torch.distributed.init_process_group)

# A very naive engine implementation using barebones PyTorch
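# presumably standing in for the DeepSpeed engine interface (checkpointing, stepping, LR control) when cfg.trainer.backend == "local"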
class Engine():
	def __init__(self, *args, **kwargs):
		self.hyper_config = kwargs.pop("hyper_config", None)

		self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
		self.optimizer = kwargs.get('optimizer', None)
		self.lr_scheduler = kwargs.get('lr_scheduler', None)

		self.global_steps = kwargs.pop("global_steps", 0)
		self.micro_steps = kwargs.pop("micro_steps", 0)
		self.global_samples = kwargs.pop("global_samples", 0)
		self.tokens_processed = kwargs.pop("tokens_processed", 0)

		self._frozen_params = set()

		self.max_nan_losses = 8
		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None

		self.current_batch_size = 0
		self._global_grad_norm = None

	def freeze(self, freeze_all=True):
		# requires a hyper_config that defines frozen_params
		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
			raise Exception("freeze() requires self.hyper_config.frozen_params to be defined")

		# freeze non-LoRA params if requested
		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
			return freeze_non_lora_weights( self.module )

		for name, param in self.module.named_parameters():
			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
				param.requires_grad_(False)
				self._frozen_params.add(param)

	def unfreeze(self):
		for p in self._frozen_params:
			p.requires_grad_(True)
		self._frozen_params.clear()

	@property
	def _training(self):
		if getattr(self, "hyper_config", None) is None:
			return True
		return self.hyper_config.training

	@property
	def global_step(self):
		return self.global_steps

	@property
	def micro_step(self):
		return self.micro_steps

	@property
	def batch_size(self):
		return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size

	@property
	def gradient_accumulation_steps(self):
		return cfg.hyperparameters.gradient_accumulation_steps
	
	@property
	def gradient_clipping(self):
		return cfg.hyperparameters.gradient_clipping

	def gather_attribute(self, *args, **kwargs):
		return gather_attribute(self.module, *args, **kwargs)

	def dispatch_attribute(self, *args, **kwargs):
		return dispatch_attribute(self.module, *args, **kwargs)

	def save_checkpoint(self, save_dir, tag ):
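		# checkpoints are written to {save_dir}/{tag}/state.pth, with a "latest" file alongside recording the most recent tag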
		if is_global_leader():
			module = self.module.state_dict()

			# if training lora
			# this is a separate path to override saving the weights
			lora = None
			if cfg.lora is not None:
				lora, module = lora_get_state_dict( module, split = True )
				save_dir = cfg.ckpt_dir / cfg.lora.full_name

			save_path = save_dir / tag / "state.pth"
			save_path.parent.mkdir(parents=True, exist_ok=True)

			torch.save({
				"module": module,
				"lora": lora,
				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
				
				"stats": {		
					"global_step": self.global_step,
					"micro_step": self.micro_step,
					"global_samples": self.global_samples,
					"tokens_processed": self.tokens_processed,
				}
			}, save_path)

			(save_dir / "latest").write_text( tag )

		torch.distributed.barrier()

	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
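		# the signature presumably mirrors DeepSpeed's load_checkpoint(); load_module_strict and load_module_only are accepted for compatibility, but strictness actually comes from cfg.trainer.strict_loading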
		# override to load the lora instead
		if cfg.lora is not None:
			load_dir = cfg.ckpt_dir / cfg.lora.full_name

		if tag is None:
			tag_path = load_dir / "latest"
			if not tag_path.exists():
				return
			tag = tag_path.read_text()

		load_path = load_dir / tag / "state.pth"
		if not load_path.exists():
			return

		state = torch.load(load_path, map_location=torch.device(cfg.device))
		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
		self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)

		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
		
		if load_optimizer_states:
			self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))
		
		if load_lr_scheduler_states:
			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))

		# "lora" is saved as None when LoRA is disabled, so guard on the value rather than the key
		if state.get('lora') is not None:
			lora_load_state_dict( self.module, state['lora'] )

	def eval(self):
		return self.module.eval()
	
	def train(self):
		return self.module.train()

	def to(self, *args, **kwargs):
		self.module = self.module.to(*args, **kwargs)
		if self.optimizer:
			# torch optimizers have no .to(); best-effort move of their state tensors instead
			for state in self.optimizer.state.values():
				for k, v in state.items():
					state[k] = v.to(*args, **kwargs) if isinstance(v, torch.Tensor) else v

		return self

	def __call__(self, *args, **kwargs):
		return self.forward(*args, **kwargs)

	@cached_property
	def device(self):
		return next(self.module.parameters()).device

	def forward(self, *args, **kwargs):
		return self.module.forward(*args, **kwargs)

	def backward(self, loss):
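		# the loss is pre-divided by the accumulation step count so gradients average across micro-batches; the GradScaler path handles AMP loss scaling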
		if self.loss_scaler is not None:
			return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
		return (loss / self.gradient_accumulation_steps).backward()


	def step(self):
		with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
			self.micro_steps += 1
			self.global_samples += self.batch_size

			# only advance the optimizer once a full accumulation window has elapsed
			if self.micro_steps % max(1, self.gradient_accumulation_steps) == 0:
				# unscale first so clipping and the reported norm see the true gradients
				if self.loss_scaler is not None:
					self.loss_scaler.unscale_(self.optimizer)
				torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)

				# capture the gradient norm before the gradients are consumed and zeroed
				self._get_grad_norm()

				self.global_steps += 1
				if self.loss_scaler is not None:
					self.loss_scaler.step(self.optimizer)
					self.loss_scaler.update()
				else:
					self.optimizer.step()
				self.optimizer.zero_grad()
	
	def _get_grad_norm(self):
		t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None

	def get_lr(self):
		lrs = []
		for param_group in self.optimizer.param_groups:
			if 'd_coeff' in param_group:
				lrs.append(param_group['d_coeff'])
			elif 'lr' in param_group:
				lrs.append(param_group['lr'])
		return lrs

	def set_lr(self, lr):
		for param_group in self.optimizer.param_groups:
			if 'd_coeff' in param_group:
				param_group['d_coeff'] = lr
			elif 'lr' in param_group:
				param_group['lr'] = lr

	def get_global_grad_norm(self):
		return self._global_grad_norm

	def traverse(self, *args, **kwargs):
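		# a full training step for this engine: forward under autocast, sum the gathered "loss" attributes, backward, then step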
		with ml.autocast():
			self.forward(*args, **kwargs)
		
		losses = self.gather_attribute("loss")
		loss = torch.stack([*losses.values()]).sum()

		if torch.isnan(loss).any():
			self.max_nan_losses = self.max_nan_losses - 1
			if self.max_nan_losses < 0:
				raise RuntimeError("Too many NaN losses detected.")

		stats = {}
		stats |= {k: v.item() for k, v in losses.items()}
		stats |= self.gather_attribute("scalar")

		self.backward(loss)
		self.step()

		return stats

# and now to ignore everything from the above
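# Engines: a name -> Engine mapping that is trained, stepped, and checkpointed as a unit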
class Engines(dict[str, Engine]):
	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
		self.setup()

	def setup(self):
		self._global_step = 0
		self._micro_step = 0
		self._batch_size = 0
		self._global_samples = 0

	@property
	def global_step(self):
		return self._global_step
	
	@property
	def micro_step(self):
		return self._micro_step

	@property
	def batch_size(self):
		return self._batch_size

	@property
	def global_samples(self):
		return self._global_samples

	def gather_attribute(self, *args, **kwargs):
		ret = {}
		for engine in self.values():
			ret |= engine.gather_attribute(*args, **kwargs)
		return ret

	def dispatch_attribute(self, *args, **kwargs):
		for engine in self.values():
			engine.dispatch_attribute(*args, **kwargs)

	def export(self, userdata={}, callback=None):
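		# writes a weights-only fp32.pth per engine (or per LoRA when enabled), bundling training stats and optional userdata; callback may rewrite the state dict before saving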
		for name, engine in self.items():
			module = engine.module.state_dict()
			lora = None
			save_path = cfg.ckpt_dir / name / "fp32.pth"

			if cfg.lora is not None:				
				lora, module = lora_get_state_dict( module, split = True )
				save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"

			state_dict = {
				'module': module,
				'lora': lora,
				"stats": {
					"global_step": engine.global_step,
					"micro_step": engine.micro_step,
					"global_samples": engine.global_samples,
					"tokens_processed": engine.tokens_processed,
				},
				"userdata": userdata
			}
			if callback:
				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )

			save_path.parent.mkdir(parents=True, exist_ok=True)
			torch.save(state_dict, save_path)
			_logger.info(f"Exported {name} to {save_path}")

	def save_checkpoint(self, tag=None):
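		# saves every training engine under cfg.ckpt_dir/{name}/{tag}/, then prunes old checkpoint directories down to cfg.trainer.keep_last_checkpoints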
		if not tag:
			tag = cfg.trainer.save_tag
		tag = tag.lower()
		if tag[:2] == "it" or tag[:4] == "step":
			tag = f'{self.global_step}'

		cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
		for name, engine in self.items():
			if not engine._training:
				continue

			save_dir = cfg.ckpt_dir / name
			try:
				engine.save_checkpoint(save_dir, tag=tag)
			except Exception as e:
				_logger.error(f'Failed to save checkpoint for engine {name}: {str(e)}')

			# it might be safer to prune before saving, but the slice gets awkward ([:0] returns an empty list; it would have to be [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None])
			if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
				checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ]
				checkpoints.sort(key=lambda x: x.stat().st_mtime)
				checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
				for d in checkpoints:
					if not d.is_dir() or not d.exists():									
						continue
					_logger.info(f"Removing {d}")
					for p in d.iterdir():
						p.unlink()
					d.rmdir()

	def load_checkpoint(self, tag=None):
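		# restores every engine from cfg.ckpt_dir/{name}; step counters are optionally reset, and the LR is re-applied afterwards when no scheduler is configured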
		if not tag:
			tag = cfg.trainer.load_tag

		for name, engine in self.items():
			load_dir = cfg.ckpt_dir / name
			engine.load_checkpoint(
				tag=tag,
				load_dir=load_dir,
				load_module_strict=cfg.trainer.strict_loading,
				load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
				load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
				load_module_only=cfg.trainer.load_module_only,
			)
			if cfg.trainer.restart_step_count:
				engine.global_steps = 0
				engine.micro_steps = 0
				engine.global_samples = 0
				engine.tokens_processed = 0

		# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
		if cfg.hyperparameters.scheduler_type == "":
			self.set_lr(cfg.hyperparameters.learning_rate)

		self._update()

	def set_lr(self, lr):
		for engine in self.values():
			if not engine._training:
				continue
			engine.set_lr(lr)

	def _update(self):
		for engine in self.values():
			self._global_step = max(self._global_step, engine.global_step)
			self._micro_step = max(self._micro_step, engine.micro_step)
			self._batch_size = max(self._batch_size, engine.batch_size)
			self._global_samples = max(self._global_samples, engine.global_samples)

	def eval(self):
		for engine in self.values():
			engine.eval()

	def train(self):
		for engine in self.values():
			engine.train()

	def traverse(self):
		stats = {}
		for name, engine in self.items():
			stat = engine.traverse()
			stats.update(flatten_dict({ name.split("-")[0]: stat }))
		return stats

	def step(self, batch, feeder: TrainFeeder = default_feeder):
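		# one full training iteration: per engine, move the batch to its device, run the feeder (optionally retrying on OOM with a shrunken batch), backward, step, then collect stats keyed by engine name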
		total_elapsed_time = 0

		stats: Any = dict()

		if cfg.trainer.gc_mode == 'step':
			do_gc()

		for name, engine in self.items():
			if not engine._training:
				continue

			device = engine.device

			if cfg.trainer.gc_mode == 'substep':
				do_gc()

			start_time = time.time()

			tries = 4
			n_ooms = torch.zeros([], device=device)			
			
			batch = to_device(batch, device)

			if not cfg.trainer.check_for_oom:
				res = feeder( engine=engine, batch=batch )
			else:
				while tries >= 0:
					try:
						res = feeder( engine=engine, batch=batch )
						break
					except RuntimeError as e:
						_logger.error(f"Forward: {str(e)}")

						if "out of memory" not in str(e):
							self.save_checkpoint()
							raise e

						# shrink batch size until it's happy (assumes a dict-style batch)
						for k in batch:
							batch[k] = batch[k][:-1]

						tries -= 1
						if tries < 0:
							# out of retries; flag the OOM so every rank can bail out together
							n_ooms += 1
							break

						# also do GC before retrying
						do_gc()

				if world_size() > 1:
					all_reduce(n_ooms)
				if n_ooms.item() > 0:
					self.save_checkpoint()
					raise RuntimeError("Out of memory during forward pass!")

			if res is None:
				continue
			
			loss, engine_stats = res
			engine_stats |= self.gather_attribute("scalar")

			n_ooms = torch.zeros([], device=device)
			
			if cfg.trainer.aggressive_optimizations:
				batch = to_device(batch, 'cpu')

			if not cfg.trainer.check_for_oom:
				engine.backward(loss)
			else:
				# to-do: properly handle when one GPU throws an OOM because it just halts
				try:
					engine.backward(loss)
				except RuntimeError as e:
					_logger.error(f"Backwards: {str(e)}")

					if "out of memory" not in str(e):
						self.save_checkpoint()
						raise e

					n_ooms += 1

				if world_size() > 1:
					all_reduce(n_ooms)

				# only abort when an OOM was actually flagged on some rank
				if n_ooms.item() > 0:
					self.save_checkpoint()
					raise RuntimeError("Out of memory during backwards pass!")

			engine.step()
			
			#torch.cuda.synchronize()

			elapsed_time = time.time() - start_time
			total_elapsed_time += elapsed_time
			grad_norm = engine.get_global_grad_norm()
			loss_scale = 1
			if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
				loss_scale = engine.optimizer.loss_scale
			
			if grad_norm is not None:
				grad_norm /= loss_scale

			stats.update(
				flatten_dict(
					{
						name.split("-")[0]: dict(
							**engine_stats,
							lr=engine.get_lr()[0],
							grad_norm=grad_norm,
							loss_scale=loss_scale if loss_scale != 1 else None,
							elapsed_time=elapsed_time,
							engine_step=engine.global_step,
							samples_processed=engine.global_samples,
							tokens_processed=engine.tokens_processed,
						)
					}
				),
			)

		self._update()

		if len(self.keys()) > 1:
			stats["elapsed_time"] = total_elapsed_time
		
		stats["it"] = self.global_step

		return stats