#!python
# cython: language_level=3

import ffmpeg
import numpy as np
import urllib.request
import os
import shutil
from pathlib import Path

MODELS_DIR = str(Path('~/.ggml-models').expanduser())


cimport numpy as cnp

cdef int SAMPLE_RATE = 16000
cdef char* TEST_FILE = b'test.wav'
cdef char* DEFAULT_MODEL = b'base'
cdef char* LANGUAGE = b'en'
cdef int N_THREADS = os.cpu_count() or 1
cdef bint PRINT_REALTIME = False
cdef bint PRINT_PROGRESS = False
cdef bint TRANSLATE = False


MODELS = {
	'ggml-tiny.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin',
	'ggml-tiny.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin',
	'ggml-base.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-base.bin',
	'ggml-base.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin',
	'ggml-small.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-small.bin',
	'ggml-small.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin',
	'ggml-medium.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin',
	'ggml-medium.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin',
	'ggml-large-v1.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin',
	'ggml-large.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-large.bin',
}
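
# The mapping above points at ggerganov's Hugging Face mirrors; the files range
# from roughly 75 MB (tiny) to a few GB (large), hence the streamed download below.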

def model_exists(model, models_dir=MODELS_DIR):
	return os.path.exists(Path(models_dir).joinpath(model))

def download_model(model, models_dir=MODELS_DIR):
	"""Downloads ggml model with the given identifier

	The filenames mirror the ones given in ggerganov's repos.
	e.g. 'small' becomes 'ggml-small.bin'

	Args:
	    model: The model identifier
	    models_dir: The path where the file is written to
	"""
	if model_exists(model, models_dir=models_dir):
		return

	print(f'Downloading {model} to {models_dir}...')
	url = MODELS[model]
	os.makedirs(models_dir, exist_ok=True)
	with urllib.request.urlopen(url) as r:
		with open(Path(models_dir).joinpath(model), 'wb') as f:
			# Stream to disk instead of buffering the whole file in memory;
			# the larger models are several gigabytes.
			shutil.copyfileobj(r, f)
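
# For example, fetching the base English model into the default cache
# (~/.ggml-models/ggml-base.en.bin):
#     download_model('ggml-base.en.bin')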


cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr = SAMPLE_RATE):
	try:
		out = (
			ffmpeg.input(file, threads=0)
			.output(
				"-", format="s16le",
				acodec="pcm_s16le",
				ac=1, ar=sr
			)
			.run(
				cmd=["ffmpeg", "-nostdin"],
				capture_stdout=True,
				capture_stderr=True
			)
		)[0]
	except Exception:
		raise RuntimeError(f"File '{file}' not found")

	cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
		np.frombuffer(out, np.int16)
		.flatten()
		.astype(np.float32)
	) / pow(2, 15)

	return frames

cdef whisper_full_params set_params(bint print_realtime, bint print_progress, bint translate, char* language, int n_threads) nogil:
	cdef whisper_full_params params = whisper_full_default_params(
		whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY
	)
	params.print_realtime = print_realtime
	params.print_progress = print_progress
	params.translate = translate
	params.language = <const char *> language
	params.n_threads = n_threads
	return params
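
# set_params is a cdef function and thus only callable from Cython code in this
# module; a minimal sketch mirroring the module-level defaults:
#     cdef whisper_full_params p = set_params(False, False, False, b'en', N_THREADS)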

cdef class Whisper:
	cdef whisper_context * ctx
	cdef whisper_full_params params
	cdef bytes _language  # owns the language bytes so params.language cannot dangle

	def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, bint print_realtime = PRINT_REALTIME, bint print_progress = PRINT_PROGRESS, bint translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, bint print_system_info = False): # not pretty; is there a way to use kwargs?
		"""Constructor for Whisper class.

		Automatically checks for model and downloads it if necessary.

		Args:
		    model: Model identifier, e.g. 'base' (see MODELS)
		    models_dir: The path where the models should be stored
		    print_realtime: whisper.cpp's real time transcription output
		    print_progress: whisper.cpp's progress indicator
		    translate: whisper.cpp's translation option
		    language: Which language to use. Must be a byte string.
		    n_threads: Number of threads to use
		    print_system_info: whisper.cpp's system info output
		"""
		if isinstance(model, bytes): # DEFAULT_MODEL is a C string, so the default arrives as bytes
			model = model.decode('utf-8')
		model_fullname = f'ggml-{model}.bin'
		download_model(model_fullname, models_dir=models_dir)
		model_path = Path(models_dir).joinpath(model_fullname)
		cdef bytes model_b = str(model_path).encode('utf8')
		self.ctx = whisper_init(model_b)
		if self.ctx == NULL:
			raise RuntimeError(f"Failed to load model '{model_fullname}'")
		self._language = language # keep a reference alive for the pointer in params
		self.params = set_params(print_realtime, print_progress, translate, self._language, n_threads)
		if print_system_info:
			print(whisper_print_system_info().decode())

	def __dealloc__(self):
		whisper_free(self.ctx)

	def transcribe(self, filename = TEST_FILE):
		"""Transcribes the audio in the given file.

		Args:
		    filename: Path to the file, as str or bytes

		Returns:
		    whisper.cpp's status code (0 on success), to pass to extract_text(...)

		Raises:
		    RuntimeError: The given file could not be loaded
		"""
		cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames

		if isinstance(filename, str): # load_audio expects bytes
			filename = filename.encode('utf-8')
		frames = load_audio(<bytes>filename)

		return whisper_full(self.ctx, self.params, &frames[0], len(frames))

	def extract_text(self, int res):
		"""Extracts the text from a transcription.

		Args:
		    res: A result id from transcribe(...)

		Returns:
		    A list of transcribed strings.

		Raises:
		    RuntimeError: The given result id was invalid.
		"""
		#print("Extracting text...")
		if res != 0:
			raise RuntimeError
		cdef int n_segments = whisper_full_n_segments(self.ctx)
		return [
			whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
		]
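

# A minimal usage sketch, assuming the built extension is importable as
# `whispercpp` and that 'test.wav' exists (both names are assumptions, not
# fixed by this module):
#
#     from whispercpp import Whisper
#
#     w = Whisper('tiny')                   # fetches ggml-tiny.bin on first use
#     result = w.transcribe('test.wav')     # returns 0 on success
#     print(''.join(w.extract_text(result)))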