import os
# Default the model cache locations to ./models/ (relative to the working directory) unless the
# user already set them. Kept above the heavy imports below, presumably because those libraries
# read these variables when they are imported.
for _env_key, _env_subdir in (
	('XDG_CACHE_HOME', './models/'),
	('TORTOISE_MODELS_DIR', './models/tortoise/'),
	('TRANSFORMERS_CACHE', './models/transformers/'),
):
	os.environ.setdefault(_env_key, os.path.realpath(os.path.join(os.getcwd(), _env_subdir)))
del _env_key, _env_subdir

import argparse
import time
import math
import json
import base64
import re
import urllib.request
import signal
import gc
import subprocess
import psutil
import yaml
import hashlib
import string
import random

from tqdm import tqdm
import torch
import torchaudio
import music_tag
import gradio as gr
import gradio.utils
import pandas as pd
import numpy as np

from glob import glob
from datetime import datetime
from datetime import timedelta

from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc


# Register the download URL for the DVAE checkpoint in tortoise's MODELS table, pinned to a specific huggingface commit.
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"

# Whisper model names (multilingual sizes, then English-only ".en" variants) and the transcription backends.
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
# Vocoder identifiers (not referenced in this part of the file; presumably offered as a setting elsewhere).
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
# Available TTS backends; 'vall-e' and 'bark' are appended further down when their optional imports succeed.
TTSES = ['tortoise']

# Flag meaning "a generation is in progress"; any function that changes it must declare `global INFERENCING`.
INFERENCING = False
# Not referenced in this part of the file; presumably set elsewhere. TODO confirm its role.
GENERATE_SETTINGS_ARGS = None

# Maps display labels to learning-rate scheduler names (presumably written into training configs; confirm).
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
# Default scheduler milestones; units are not shown here (presumably epochs; confirm).
LEARNING_RATE_SCHEDULE = [ 2, 4, 9, 18, 25, 33, 50 ]

# Cache of torchaudio Resample transforms keyed by "input_rate:output_rate"; filled lazily by resample().
RESAMPLERS = {}

# Limits for training clips. Durations are assumed to be seconds from the naming; confirm against dataset-prep code.
MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669
MAX_TRAINING_CHAR_LENGTH = 200

# Flipped to True by the optional VALL-E / Bark import blocks below when those packages import cleanly.
VALLE_ENABLED = False
BARK_ENABLED = False

# When True, failed optional imports print their full traceback; when False they are skipped silently.
VERBOSE_DEBUG = True

import traceback

# Optional: Whisper text normalizers and language table. If the import fails these names stay
# undefined; with VERBOSE_DEBUG the traceback is printed.
try:
	from whisper.normalizers.english import EnglishTextNormalizer
	from whisper.normalizers.basic import BasicTextNormalizer
	from whisper.tokenizer import LANGUAGES 

	print("Whisper detected")
except Exception as e:
	if VERBOSE_DEBUG:
		print(traceback.format_exc())
	pass

# Optional: VALL-E backend. VALLE_ENABLED is only set once every import in this block succeeds.
try:
	from vall_e.emb.qnt import encode as valle_quantize
	from vall_e.emb.g2p import encode as valle_phonemize

	from vall_e.inference import TTS as VALLE_TTS

	import soundfile

	print("VALL-E detected")
	VALLE_ENABLED = True
except Exception as e:
	if VERBOSE_DEBUG:
		print(traceback.format_exc())
	pass

# Expose VALL-E as a selectable backend only when it imported successfully.
if VALLE_ENABLED:
	TTSES.append('vall-e')

# torchaudio.set_audio_backend('soundfile')

# Optional: Bark backend, plus the helpers used by Bark_TTS / generate_bark
# (encodec's convert_audio for reference audio, scipy's write_wav for saving output).
# BARK_ENABLED is only set once every import in this block succeeds.
try:
	import bark
	from bark import text_to_semantic
	from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
	from bark.api import generate_audio as bark_generate_audio
	from encodec.utils import convert_audio

	from scipy.io.wavfile import write as write_wav

	print("Bark detected")
	BARK_ENABLED = True
except Exception as e:
	if VERBOSE_DEBUG:
		print(traceback.format_exc())
	pass

if BARK_ENABLED:
	TTSES.append('bark')

	def semantic_to_audio_tokens(
		semantic_tokens,
		history_prompt = None,
		temp = 0.7,
		silent = False,
		output_full = False,
	):
		"""Turn Bark semantic tokens into fine audio tokens (coarse pass, then fine pass).

		Args:
			semantic_tokens: output of Bark's text->semantic stage.
			history_prompt: optional Bark voice prompt passed to both stages.
			temp: temperature for the coarse stage (the fine stage always uses 0.5).
			silent: forwarded to generate_coarse.
			output_full: when True, return a dict with all three token sets instead of only the fine tokens.
		"""
		coarse = generate_coarse(
			semantic_tokens,
			history_prompt=history_prompt,
			temp=temp,
			silent=silent,
			use_kv_caching=True,
		)
		fine = generate_fine(coarse, history_prompt=history_prompt, temp=0.5)

		if not output_full:
			return fine

		return {
			"semantic_prompt": semantic_tokens,
			"coarse_prompt": coarse,
			"fine_prompt": fine,
		}

	class Bark_TTS():
		"""Bark text-to-speech wrapper exposing the `inference(text, voice, ...)` call used by generate_bark.

		On construction it preloads Bark's models on the GPU (optionally the small variants) and probes
		two optional packages:
		  - `vocos`: when importable, Vocos decodes audio tokens instead of Bark's `codec_decode`.
		  - `hubert`: when importable (and its checkpoints install), create_voice derives semantic tokens
		    from the reference audio with HuBERT instead of generating them from the transcript.
		"""
		def __init__(self, small=False):
			"""Load Bark models and detect optional Vocos / HuBERT support.

			Args:
				small: if True, use Bark's small text/coarse/fine models.
			"""
			self.input_sample_rate = BARK_SAMPLE_RATE
			self.output_sample_rate = BARK_SAMPLE_RATE # args.output_sample_rate

			preload_models(
				text_use_gpu=True,
				coarse_use_gpu=True,
				fine_use_gpu=True,
				codec_use_gpu=True,

				text_use_small=small,
				coarse_use_small=small,
				fine_use_small=small,

				force_reload=False
			)

			self.device = get_device_name()

			try:
				from vocos import Vocos
				self.vocos_enabled = True
				print("Vocos detected")
			except Exception as e:
				if VERBOSE_DEBUG:
					print(traceback.format_exc())
				self.vocos_enabled = False

			try:
				# Imported here only to detect availability; create_voice re-imports what it needs.
				from hubert.hubert_manager import HuBERTManager
				from hubert.pre_kmeans_hubert import CustomHubert
				from hubert.customtokenizer import CustomTokenizer

				hubert_manager = HuBERTManager()
				hubert_manager.make_sure_hubert_installed()
				hubert_manager.make_sure_tokenizer_installed()

				self.hubert_enabled = True
				print("HuBERT detected")
			except Exception as e:
				if VERBOSE_DEBUG:
					print(traceback.format_exc())
				self.hubert_enabled = False

			if self.vocos_enabled:
				self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)

		def _prompt_path( self, voice ):
			"""Return the .npz history-prompt path for `voice` inside the installed bark package.

			'/' in nested voice names is flattened to '_', matching the prompt name handed to Bark.
			Both create_voice (writer) and inference (existence check) use this, so they agree.
			"""
			bark_location = os.path.dirname(os.path.abspath(bark.__file__))
			return os.path.join(bark_location, "assets", "prompts", voice.replace("/", "_") + ".npz")

		def create_voice( self, voice ):
			"""Build a Bark history prompt (.npz) for `voice` from its transcribed training audio.

			Reads ./training/{voice}/whisper.json and picks a random reference clip: a sliced segment
			when it exists on disk, otherwise the whole source file. The clip is encoded with Bark's
			codec model, and semantic tokens are derived either with HuBERT (when enabled) or by
			generating them from the transcript. The fine/coarse/semantic prompts are written to
			`self._prompt_path(voice)`.

			Raises:
				FileNotFoundError: if the transcription file is missing or lists no audio.
			"""
			transcription_json = f'./training/{voice}/whisper.json'
			if not os.path.exists(transcription_json):
				# raising a bare str is itself a TypeError in Python 3; raise a real exception
				raise FileNotFoundError(f"Transcription for voice not found: {voice}")

			with open(transcription_json, 'r', encoding="utf-8") as f:
				transcriptions = json.load(f)

			candidates = []
			for file in transcriptions:
				result = transcriptions[file]
				added = 0

				for segment in result['segments']:
					path = file.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
					# check if the slice actually exists
					if not os.path.exists(f'./training/{voice}/audio/{path}'):
						continue

					candidates.append((
						path,
						segment['end'] - segment['start'],
						segment['text'],
					))
					added += 1

				# if nothing got added (assuming because nothing was sliced), use the master file
				if added == 0:
					# NOTE(review): taking max() of the starts gives the last segment's start,
					# so this "duration" is only the last segment's length. It only affects the sort below.
					start = 0
					end = 0
					for segment in result['segments']:
						start = max( start, segment['start'] )
						end = max( end, segment['end'] )

					candidates.append((
						file,
						end - start,
						result['text'],
					))

			if not candidates:
				raise FileNotFoundError(f"No reference audio listed in transcription for voice: {voice}")

			candidates.sort(key=lambda x: x[1])
			candidate = random.choice(candidates)
			audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
			text = candidate[-1]

			print("Using as reference:", audio_filepath, text)

			# Load and pre-process the audio waveform to the codec's rate/channel layout
			model = load_codec_model(use_gpu=True)
			wav, sr = torchaudio.load(audio_filepath)
			wav = convert_audio(wav, sr, model.sample_rate, model.channels)

			if self.hubert_enabled:
				# __init__ only imported these locally; re-import here (hubert_enabled means it succeeded once)
				from hubert.pre_kmeans_hubert import CustomHubert
				from hubert.customtokenizer import CustomTokenizer

				wav = wav.to(self.device)

				# Extract discrete codes from EnCodec
				with torch.no_grad():
					encoded_frames = model.encode(wav.unsqueeze(0))
				codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]

				# Load the HuBERT model and its semantic tokenizer
				hubert_model = CustomHubert(checkpoint_path='./data/models/hubert/hubert.pt').to(self.device)
				tokenizer = CustomTokenizer.load_from_checkpoint('./data/models/hubert/tokenizer.pth').to(self.device)

				semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
				semantic_tokens = tokenizer.get_token(semantic_vectors)

				# move both to cpu numpy for np.savez
				codes = codes.cpu().numpy()
				semantic_tokens = semantic_tokens.cpu().numpy()
			else:
				wav = wav.unsqueeze(0).to(self.device)

				# Extract discrete codes from EnCodec
				with torch.no_grad():
					encoded_frames = model.encode(wav)
				codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy()  # [n_q, T]

				# reference length in seconds bounds how long the semantic generation may run
				seconds = wav.shape[-1] / model.sample_rate

				semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)

			np.savez(self._prompt_path(voice), fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)

		def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
			"""Synthesize `text` with Bark, using `voice` as the history prompt.

			Args:
				text: text to speak.
				voice: voice name. "random" means no history prompt. A prompt .npz is created via
					create_voice the first time a voice is used.
				text_temp: temperature for the text->semantic stage.
				waveform_temp: temperature for the coarse audio-token stage.

			Returns:
				(wav, BARK_SAMPLE_RATE). wav is a torch tensor when Vocos is enabled, otherwise
				whatever Bark's codec_decode returns.
			"""
			if voice == "random":
				voice = None
			else:
				# check the same path create_voice writes to, so nested ("a/b") voices are found too
				if not os.path.exists(self._prompt_path(voice)):
					self.create_voice( voice )
				voice = voice.replace("/", "_")
				if voice not in ALLOWED_PROMPTS:
					ALLOWED_PROMPTS.add( voice )

			semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
			audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )

			# use the per-instance flag from __init__; self.vocos only exists when it is True
			if self.vocos_enabled:
				audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
				features = self.vocos.codes_to_features(audio_tokens_torch)
				wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
			else:
				wav = codec_decode( audio_tokens )

			return ( wav, BARK_SAMPLE_RATE )

# Module-level runtime state shared across functions; functions rebind these via `global` declarations.
# Presumably the parsed argparse/config namespace; read as `args.<option>` throughout.
args = None
# Active TTS backend instance (presumably one of the backend classes, e.g. Bark_TTS) and its loading flag.
tts = None
tts_loading = False
# Not referenced in this part of the file; presumably the web UI handle.
webui = None
# voicefixer model; generate_* call load_voicefixer() when this is falsy and then use voicefixer.restore(...).
voicefixer = None

# Whisper transcription / alignment models (generate_* call unload_whisper() before synthesizing).
whisper_model = None
whisper_align_model = None

# Not referenced in this part of the file; presumably a handle on an in-progress training run.
training_state = None

# Not referenced in this part of the file; presumably the currently selected voice. Confirm.
current_voice = None

def cleanup_voice_name( name ):
	"""Return the last '/'-separated component of a voice name, e.g. 'group/alice' -> 'alice'."""
	_, _, tail = name.rpartition("/")
	return tail

def resample( waveform, input_rate, output_rate=44100 ):
	"""Downmix `waveform` to mono and resample it from `input_rate` to `output_rate`.

	Channels are averaged over dim 0 with keepdim, so a (C, T) input becomes (1, T).
	Resample transforms are built once per (input_rate, output_rate) pair and cached in RESAMPLERS.

	Returns:
		(waveform, output_rate)
	"""
	mono = waveform.mean(dim=0, keepdim=True)

	if input_rate != output_rate:
		cache_key = f'{input_rate}:{output_rate}'
		transform = RESAMPLERS.get(cache_key)
		if transform is None:
			transform = torchaudio.transforms.Resample(
				input_rate,
				output_rate,
				lowpass_filter_width=16,
				rolloff=0.85,
				resampling_method="kaiser_window",
				beta=8.555504641634386,
			)
			RESAMPLERS[cache_key] = transform
		mono = transform(mono)

	return mono, output_rate

def generate(**kwargs):
	"""Dispatch a generation request to the backend named by `args.tts_backend`.

	All keyword arguments are forwarded unchanged to generate_tortoise, generate_valle or generate_bark.

	Raises:
		ValueError: if `args.tts_backend` names no known backend. Previously this returned None,
			which then failed later when the caller unpacked the result.
	"""
	backends = {
		"tortoise": generate_tortoise,
		"vall-e": generate_valle,
		"bark": generate_bark,
	}
	backend = backends.get(args.tts_backend)
	if backend is None:
		raise ValueError(f"Unknown TTS backend: {args.tts_backend}")
	return backend(**kwargs)

def generate_bark(**kwargs):
	"""Generate speech with the Bark backend and save results under `args.results_folder/{voice}/`.

	Keyword arguments read here: 'voice', 'seed' (0 is treated as None), 'temperature',
	'delimiter', 'text', 'candidates', and optionally 'progress' (forwarded to notify_progress).

	Pipeline:
	- Split the text into lines. A leading JSON object such as `{"voice": "x"} text` overrides
	  settings for that line.
	- Synthesize each line, resample it to args.output_sample_rate and apply args.output_volume.
	- Concatenate multi-line results into a "_combined" clip.
	- Optionally run voicefixer and prune intermediate files.
	- Record the generation settings, either embedded in the .wav (args.embed_output_metadata)
	  or as a sidecar .json. They are also written to ./config/generate.json.

	Returns:
		(sample_voice, output_voices, stats). sample_voice is always None for Bark,
		output_voices lists the final .wav paths, and stats is [[seed, "<seconds>"]].

	Raises:
		Exception: if the TTS is still initializing, or a per-line settings prefix is not valid JSON.
	"""
	global args
	global tts
	# without this, the assignments below would only bind a local and never update the module flag
	global INFERENCING

	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	if parameters['seed'] == 0:
		parameters['seed'] = None

	usedSeed = parameters['seed']

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		if progress is not None:
			notify_progress("Initializing TTS...", progress=progress)
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	# Bark produces no reference-voice preview to hand back
	sample_voice = None

	def get_settings( override=None ):
		"""Return the keyword arguments for tts.inference, with per-line overrides applied."""
		settings = {
			'voice': parameters['voice'],
			'text_temp': float(parameters['temperature']),
			'waveform_temp': float(parameters['temperature']),
		}
		if override is not None:
			# only keys tts.inference understands are honoured
			for k in override:
				if k in settings:
					settings[k] = override[k]
		return settings

	if not parameters['delimiter'] or parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	# the delimiter is always non-empty at this point
	if parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	voice_name = cleanup_voice_name(voice)
	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	# Next free output index: one past the highest index among existing "{voice_name}_NNNN*" outputs.
	idx = 0
	idx_cache = {}
	for filename in os.listdir(outdir):
		# splitext keeps the dot and [1:] strips it, so compare against bare extensions.
		# Comparing against ".wav"/".json" never matched, which kept idx at 0 and overwrote old outputs.
		extension = os.path.splitext(filename)[-1][1:]
		if extension not in ("json", "wav"):
			continue
		match = re.findall(rf"^{re.escape(voice_name)}_(\d+)(?:.+?)?{extension}$", filename)
		if match:
			idx_cache[int(match[0])] = True

	if idx_cache:
		idx = max(idx_cache) + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		"""Build the output stem: index, then line number (multi-line) or "combined", then candidate number."""
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		"""Return a JSON-serializable copy of the request parameters, with matching `settings` keys overlaid."""
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		# the progress callback is not serializable
		info.pop('progress', None)

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]
		return info

	# Bark_TTS returns a torch tensor when Vocos decoding is enabled, so it is saved differently
	use_vocos = getattr(tts, "vocos_enabled", False)

	INFERENCING = True
	try:
		for line, cut_text in enumerate(texts):
			tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
			print(f"{tqdm_prefix} Generating line: {cut_text}")
			start_time = time.time()

			# do setting editing
			match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
			override = None
			if match:
				match = match[0]
				try:
					override = json.loads(match[0])
					cut_text = match[1].strip()
				except Exception as e:
					raise Exception("Prompt settings editing requested, but received invalid JSON") from e

			settings = get_settings( override=override )

			gen = tts.inference(cut_text, **settings )

			run_time = time.time()-start_time
			print(f"Generating line took {run_time} seconds")

			if not isinstance(gen, list):
				gen = [gen]

			for j, g in enumerate(gen):
				wav, sr = g
				name = get_name(line=line, candidate=j)

				settings['text'] = cut_text
				settings['time'] = run_time
				settings['datetime'] = datetime.now().isoformat()

				# save here in case some error happens mid-batch
				path = f'{outdir}/{voice_name}_{name}.wav'
				if use_vocos:
					torchaudio.save(path, wav.cpu(), sr)
				else:
					write_wav(path, sr, wav)
				wav, sr = torchaudio.load(path)

				audio_cache[name] = {
					'audio': wav,
					'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
				}

			# release this line's raw output before generating the next one
			gen = None
	finally:
		# always clear the flag, even if a line fails
		INFERENCING = False
		do_gc()

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{voice_name}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio_clips.append(audio_cache[name]['audio'])

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{voice_name}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			name = get_name(candidate=candidate)
			if name in audio_cache:
				audio_cache[name]['output'] = True
			else:
				# expected name missing: mark everything that was generated as output
				for name in audio_cache:
					audio_cache[name]['output'] = True

	if args.voice_fixer:
		if not voicefixer:
			notify_progress("Loading voicefix...", progress=progress)
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if not audio_cache[name].get('output'):
					continue

				path = f'{outdir}/{voice_name}_{name}.wav'
				fixed = f'{outdir}/{voice_name}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				# the fixed file replaces the original as the final output
				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if not audio_cache[name].get('output'):
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{voice_name}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{voice_name}_{name}.wav')

		if not args.embed_output_metadata:
			with open(f'{outdir}/{voice_name}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in tqdm(audio_cache, desc="Embedding metadata..."):
			if audio_cache[name].get('pruned'):
				continue

			metadata = music_tag.load_file(f"{outdir}/{voice_name}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	info = get_info(voice=voice, latents=False)
	# guard against empty input text producing no outputs
	saved_to = output_voices[0] if output_voices else None
	print(f"Generation took {info['time']} seconds, saved to '{saved_to}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)

def generate_valle(**kwargs):
	"""Generate speech with the VALL-E backend and save results under `args.results_folder/{voice}/`.

	Keyword arguments read here: 'voice', 'seed' (0 is treated as None), 'temperature',
	'num_autoregressive_samples', 'delimiter', 'text', 'candidates', and optionally
	'progress' (forwarded to notify_progress).

	Reference audio for a voice is a random .wav from ./training/{voice}/audio/, falling back
	to ./voices/{voice}/. tts.inference writes each line directly to its output path. The line
	is then reloaded, resampled, volume-adjusted, combined, optionally voicefixed, and tagged
	with its settings.

	Returns:
		(sample_voice, output_voices, stats). sample_voice is always None here,
		output_voices lists the final .wav paths, and stats is [[seed, "<seconds>"]].

	Raises:
		Exception: if the TTS is still initializing, or a per-line settings prefix is not valid JSON.
		FileNotFoundError: if no reference .wav can be found for a requested voice.
	"""
	global args
	global tts
	# without this, the assignments below would only bind a local and never update the module flag
	global INFERENCING

	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	if parameters['seed'] == 0:
		parameters['seed'] = None

	usedSeed = parameters['seed']

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		if progress is not None:
			notify_progress("Initializing TTS...", progress=progress)
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	# this backend produces no reference-voice preview to hand back
	sample_voice = None

	voice_cache = {}
	def fetch_voice( voice ):
		"""Pick, and memoize for this call, one random reference .wav for `voice`.

		Prefers ./training/{voice}/audio/ and falls back to ./voices/{voice}/ when that
		directory is missing or empty.
		"""
		if voice in voice_cache:
			return voice_cache[voice]
		voice_dir = f'./training/{voice}/audio/'

		if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0:
			voice_dir = f'./voices/{voice}/'

		files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
		if not files:
			raise FileNotFoundError(f"No .wav reference files found for voice: {voice} (looked in {voice_dir})")
		voice_cache[voice] = random.choice(files)
		return voice_cache[voice]

	def get_settings( override=None ):
		"""Return tts.inference keyword arguments plus a 'references' list, with per-line overrides applied."""
		settings = {
			'ar_temp': float(parameters['temperature']),
			'nar_temp': float(parameters['temperature']),
			'max_ar_steps': parameters['num_autoregressive_samples'],
		}

		selected_voice = voice
		if override is not None:
			if 'voice' in override:
				selected_voice = override['voice']

			# only keys tts.inference understands are honoured
			for k in override:
				if k in settings:
					settings[k] = override[k]

		# NOTE(review): fetch_voice memoizes one file per voice, so all three references are the
		# same clip. Confirm whether distinct references were intended.
		settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ]
		return settings

	if not parameters['delimiter'] or parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	# the delimiter is always non-empty at this point
	if parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	voice_name = cleanup_voice_name(voice)
	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	# Next free output index: one past the highest index among existing "{voice_name}_NNNN*" outputs.
	# Files are saved under cleanup_voice_name(voice), so the scan must use that name too.
	idx = 0
	idx_cache = {}
	for filename in os.listdir(outdir):
		# splitext keeps the dot and [1:] strips it, so compare against bare extensions.
		# Comparing against ".wav"/".json" never matched, which kept idx at 0 and overwrote old outputs.
		extension = os.path.splitext(filename)[-1][1:]
		if extension not in ("json", "wav"):
			continue
		match = re.findall(rf"^{re.escape(voice_name)}_(\d+)(?:.+?)?{extension}$", filename)
		if match:
			idx_cache[int(match[0])] = True

	if idx_cache:
		idx = max(idx_cache) + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		"""Build the output stem: index, then line number (multi-line) or "combined", then candidate number."""
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		"""Return a JSON-serializable copy of the request parameters, with matching `settings` keys overlaid."""
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		# the progress callback is not serializable
		info.pop('progress', None)

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]
		return info

	INFERENCING = True
	try:
		for line, cut_text in enumerate(texts):
			tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
			print(f"{tqdm_prefix} Generating line: {cut_text}")
			start_time = time.time()

			# do setting editing
			match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
			override = None
			if match:
				match = match[0]
				try:
					override = json.loads(match[0])
					cut_text = match[1].strip()
				except Exception as e:
					raise Exception("Prompt settings editing requested, but received invalid JSON") from e

			name = get_name(line=line, candidate=0)

			settings = get_settings( override=override )
			references = settings.pop("references")
			# tts.inference writes the clip itself to out_path; it is reloaded below
			settings['out_path'] = f'{outdir}/{voice_name}_{name}.wav'

			gen = tts.inference(cut_text, references, **settings )

			run_time = time.time()-start_time
			print(f"Generating line took {run_time} seconds")

			if not isinstance(gen, list):
				gen = [gen]

			for j, g in enumerate(gen):
				wav, sr = g
				name = get_name(line=line, candidate=j)

				settings['text'] = cut_text
				settings['time'] = run_time
				settings['datetime'] = datetime.now().isoformat()

				wav, sr = torchaudio.load(f'{outdir}/{voice_name}_{name}.wav')

				audio_cache[name] = {
					'audio': wav,
					'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
				}

			# release this line's raw output before generating the next one
			gen = None
	finally:
		# always clear the flag, even if a line fails
		INFERENCING = False
		do_gc()

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{voice_name}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio_clips.append(audio_cache[name]['audio'])

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{voice_name}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			name = get_name(candidate=candidate)
			audio_cache[name]['output'] = True

	if args.voice_fixer:
		if not voicefixer:
			notify_progress("Loading voicefix...", progress=progress)
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if not audio_cache[name].get('output'):
					continue

				path = f'{outdir}/{voice_name}_{name}.wav'
				fixed = f'{outdir}/{voice_name}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				# the fixed file replaces the original as the final output
				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if not audio_cache[name].get('output'):
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{voice_name}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{voice_name}_{name}.wav')

		if not args.embed_output_metadata:
			with open(f'{outdir}/{voice_name}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in tqdm(audio_cache, desc="Embedding metadata..."):
			if audio_cache[name].get('pruned'):
				continue

			metadata = music_tag.load_file(f"{outdir}/{voice_name}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	info = get_info(voice=voice, latents=False)
	# guard against empty input text producing no outputs
	saved_to = output_voices[0] if output_voices else None
	print(f"Generation took {info['time']} seconds, saved to '{saved_to}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)

def generate_tortoise(**kwargs):
	"""
	Generate speech with the TorToiSe backend and save the results under `{args.results_folder}/{voice}/`.

	`kwargs` is copied into a `parameters` dict. Keys read here: voice, progress (optional), seed,
	mic_audio, voice_latents_chunks, temperature, top_p, diffusion_temperature, length_penalty,
	repetition_penalty, cond_free_k, num_autoregressive_samples, diffusion_iterations, candidates,
	diffusion_sampler, breathing_room, experimentals, cvvp_weight, delimiter, text, emotion, prompt.

	A line of text of the form `{json} text` overrides settings (and optionally 'voice') for that line only.

	Returns:
		(sample_voice, output_voices, stats):
		- sample_voice: `(tts.input_sample_rate, ndarray)` of the primary voice's reference audio, or None
		- output_voices: list of paths to the final output .wav files
		- stats: `[[seed, "<total seconds, 3 decimals>"]]`
	"""
	global args
	global tts
	# cancel_generate() reads the module-level flag; without this declaration the assignments
	# below only created a local variable and cancelling never took effect
	global INFERENCING

	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	# 0 means "no fixed seed"; the seed actually used is read back from tts.tts() below
	if parameters['seed'] == 0:
		parameters['seed'] = None

	usedSeed = parameters['seed']

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	# reference audio of the primary voice; filled in by get_settings() and returned to the caller
	sample_voice = None

	voice_cache = {}
	def fetch_voice( voice ):
		"""
		Return `(voice_samples, conditioning_latents, reference_audio)` for `voice`, cached per voice and
		autoregressive model hash. Whenever samples are available, missing latents are computed from them,
		the samples are concatenated into `reference_audio`, and `voice_samples` is returned as None.
		"""
		cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
		if cache_key in voice_cache:
			return voice_cache[cache_key]

		print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
		reference_audio = None
		if voice == "microphone":
			if parameters['mic_audio'] is None:
				raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
			voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None
		elif voice == "random":
			voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
		else:
			if progress is not None:
				notify_progress(f"Loading voice: {voice}", progress=progress)

			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)

		if voice_samples and len(voice_samples) > 0:
			if conditioning_latents is None:
				conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks'])

			reference_audio = torch.cat(voice_samples, dim=-1).squeeze().cpu()
			voice_samples = None

		voice_cache[cache_key] = (voice_samples, conditioning_latents, reference_audio)
		return voice_cache[cache_key]

	def get_settings( override=None ):
		"""
		Build the keyword arguments for tts.tts() from `parameters`, apply a per-line `override`
		(only keys already present in the settings, plus 'voice'), load the models it selects,
		and attach the selected voice's samples/latents.
		"""
		nonlocal sample_voice

		settings = {
			'temperature': float(parameters['temperature']),

			'top_p': float(parameters['top_p']),
			'diffusion_temperature': float(parameters['diffusion_temperature']),
			'length_penalty': float(parameters['length_penalty']),
			'repetition_penalty': float(parameters['repetition_penalty']),
			'cond_free_k': float(parameters['cond_free_k']),

			'num_autoregressive_samples': parameters['num_autoregressive_samples'],
			'sample_batch_size': args.sample_batch_size,
			'diffusion_iterations': parameters['diffusion_iterations'],

			'voice_samples': None,
			'conditioning_latents': None,

			'use_deterministic_seed': parameters['seed'],
			'return_deterministic_state': True,
			'k': parameters['candidates'],
			'diffusion_sampler': parameters['diffusion_sampler'],
			'breathing_room': parameters['breathing_room'],
			'half_p': "Half Precision" in parameters['experimentals'],
			'cond_free': "Conditioning-Free" in parameters['experimentals'],
			'cvvp_amount': parameters['cvvp_weight'],

			'autoregressive_model': args.autoregressive_model,
			'diffusion_model': args.diffusion_model,
			'tokenizer_json': args.tokenizer_json,
		}

		# could be better to just do a ternary on everything above, but i am not a professional
		selected_voice = voice
		if override is not None:
			if 'voice' in override:
				selected_voice = override['voice']

			for k in override:
				if k not in settings:
					continue
				settings[k] = override[k]

		if settings['autoregressive_model'] is not None:
			if settings['autoregressive_model'] == "auto":
				settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
			tts.load_autoregressive_model(settings['autoregressive_model'])

		if settings['diffusion_model'] is not None:
			if settings['diffusion_model'] == "auto":
				settings['diffusion_model'] = deduce_diffusion_model(selected_voice)
			tts.load_diffusion_model(settings['diffusion_model'])

		if settings['tokenizer_json'] is not None:
			tts.load_tokenizer_json(settings['tokenizer_json'])

		settings['voice_samples'], settings['conditioning_latents'], reference_audio = fetch_voice(voice=selected_voice)
		# the reference audio used to be discarded here, so the returned sample_voice was always None
		if sample_voice is None and selected_voice == voice:
			sample_voice = reference_audio

		# clamp it down for the insane users who want this
		# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
		settings['sample_batch_size'] = args.sample_batch_size
		if not settings['sample_batch_size']:
			settings['sample_batch_size'] = tts.autoregressive_batch_size
		if settings['num_autoregressive_samples'] < settings['sample_batch_size']:
			settings['sample_batch_size'] = settings['num_autoregressive_samples']

		if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
			print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.")
			settings['cvvp_amount'] = 0

		return settings

	if not parameters['delimiter']:
		parameters['delimiter'] = "\n"
	elif parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	if parameters['delimiter'] and parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	# pick the next free output index: outputs are written as `{cleanup_voice_name(voice)}_{name}.{wav,json}`,
	# so scan for exactly that (regex-escaped) prefix
	idx = 0
	idx_cache = {}
	file_prefix = re.escape(cleanup_voice_name(voice))
	for filename in os.listdir(outdir):
		extension = os.path.splitext(filename)[-1][1:]
		if extension != "json" and extension != "wav":
			continue
		match = re.findall(rf"^{file_prefix}_(\d+)(?:.+?)?{extension}$", filename)
		if match and len(match) > 0:
			key = int(match[0])
			idx_cache[key] = True

	if len(idx_cache) > 0:
		idx = max(idx_cache.keys()) + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		"""Output name for a line/candidate: `{idx}[_combined|_{line}][_{candidate}]`."""
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		"""
		Build the metadata saved with an output: `parameters` plus timing/model info, overlaid with any
		matching keys from `settings`. When `latents` is set, the voice's stored conditioning latents file
		(if present) is embedded base64-encoded under 'latents'.
		"""
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		info['model'] = tts.autoregressive_model_path
		info['model_hash'] = tts.autoregressive_model_hash

		info.pop('progress', None)

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]

			if 'half_p' in settings and 'cond_free' in settings:
				info['experimentals'] = []
				if settings['half_p']:
					info['experimentals'].append("Half Precision")
				if settings['cond_free']:
					info['experimentals'].append("Conditioning-Free")

		if latents and "latents" not in info:
			# NOTE(review): this replaces the `voice` argument (the per-line override voice at the call site
			# in the generation loop) with the request's voice; confirm which voice's latents should be embedded
			voice = info['voice']
			model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8]

			voice_dir = f'{get_voice_dir()}/{voice}/'
			latents_path = f'{voice_dir}/cond_latents_{model_hash}.pth'

			if voice == "random" or voice == "microphone":
				# latents for these voices are produced on the fly; persist the ones used for this line so they
				# can be embedded (this used to save the enclosing scope's `conditioning_latents`, which was always None)
				if settings is not None and settings['conditioning_latents'] is not None:
					os.makedirs(voice_dir, exist_ok=True)
					torch.save(settings['conditioning_latents'], latents_path)

			if latents_path and os.path.exists(latents_path):
				# best-effort: metadata is still written without the latents if reading fails
				try:
					with open(latents_path, 'rb') as f:
						info['latents'] = base64.b64encode(f.read()).decode("ascii")
				except Exception as e:
					print(f"Failed to embed latents from {latents_path}: {e}")

		return info

	INFERENCING = True
	gen = None  # bound up front so `del gen` below is safe when there are no lines
	try:
		for line, cut_text in enumerate(texts):
			if should_phonemize():
				cut_text = phonemizer( cut_text )

			if parameters['emotion'] == "Custom":
				if parameters['prompt'] and parameters['prompt'].strip() != "":
					cut_text = f"[{parameters['prompt']},] {cut_text}"
			elif parameters['emotion'] != "None" and parameters['emotion']:
				cut_text = f"[I am really {parameters['emotion'].lower()},] {cut_text}"

			tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
			print(f"{tqdm_prefix} Generating line: {cut_text}")
			start_time = time.time()

			# do setting editing: a line of the form `{json} text` overrides settings for this line
			match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
			override = None
			if match and len(match) > 0:
				match = match[0]
				try:
					override = json.loads(match[0])
					cut_text = match[1].strip()
				except ValueError as e:
					raise Exception("Prompt settings editing requested, but received invalid JSON") from e

			settings = get_settings( override=override )
			gen, additionals = tts.tts(cut_text, **settings )

			parameters['seed'] = additionals[0]
			run_time = time.time()-start_time
			print(f"Generating line took {run_time} seconds")

			if not isinstance(gen, list):
				gen = [gen]

			for j, g in enumerate(gen):
				audio = g.squeeze(0).cpu()
				name = get_name(line=line, candidate=j)

				settings['text'] = cut_text
				settings['time'] = run_time
				settings['datetime'] = datetime.now().isoformat()
				if args.tts_backend == "tortoise":
					settings['model'] = tts.autoregressive_model_path
					settings['model_hash'] = tts.autoregressive_model_hash

				audio_cache[name] = {
					'audio': audio,
					'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
				}
				# save here in case some error happens mid-batch
				torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)
	finally:
		# reset even if generation raised, so cancel_generate() never targets a finished run
		INFERENCING = False

	del gen
	do_gc()

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio = audio_cache[name]['audio']
				audio_clips.append(audio)

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			name = get_name(candidate=candidate)
			audio_cache[name]['output'] = True


	if args.voice_fixer:
		if not voicefixer:
			notify_progress("Loading voicefix...", progress=progress)
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

		if not args.embed_output_metadata:
			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in tqdm(audio_cache, desc="Embedding metadata..."):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	if sample_voice is not None:
		sample_voice = (tts.input_sample_rate, sample_voice.numpy())

	info = get_info(voice=voice, latents=False)
	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)

def cancel_generate():
	"""Raise the TorToiSe stop signal for the generation in progress; a no-op while idle."""
	if INFERENCING:
		# tortoise.api is only touched when there is a running generation to stop
		import tortoise.api
		tortoise.api.STOP_SIGNAL = True

def hash_file(path, algo="md5", buffer_size=0):
	"""
	Return the hex digest of the file at `path`.

	Args:
		path: file to hash.
		algo: "md5" or "sha1".
		buffer_size: read size in bytes; values <= 0 use a 1 MiB default. The file is always hashed
			incrementally, so large model checkpoints are never loaded into memory in one piece.

	Raises:
		Exception: if `algo` is unsupported or `path` does not exist.
	"""
	if algo == "md5":
		hasher = hashlib.md5()
	elif algo == "sha1":
		hasher = hashlib.sha1()
	else:
		raise Exception(f'Unknown hash algorithm specified: {algo}')

	if not os.path.exists(path):
		raise Exception(f'Path not found: {path}')

	chunk_size = buffer_size if buffer_size > 0 else 1024 * 1024
	with open(path, 'rb') as f:
		for chunk in iter(lambda: f.read(chunk_size), b""):
			hasher.update(chunk)

	return hasher.hexdigest()

def update_baseline_for_latents_chunks( voice ):
	"""
	Suggest a `voice_latents_chunks` value for `voice`, and record `voice` as the current voice.

	Returns:
		0 when ./training/{voice}/train.txt exists (compute_latents() then uses that dataset);
		1 when the voice directory is missing or contains no .wav audio;
		otherwise a chunk count derived from the .wav durations, never below 1.
	"""
	global current_voice
	current_voice = voice

	path = f'{get_voice_dir()}/{voice}/'
	if not os.path.isdir(path):
		return 1

	dataset_file = f'./training/{voice}/train.txt'
	if os.path.exists(dataset_file):
		return 0 # 0 will leverage using the LJspeech dataset for computing latents

	total = 0
	total_duration = 0

	for file in os.listdir(path):
		if file[-4:] != ".wav":
			continue

		metadata = torchaudio.info(f'{path}/{file}')
		total_duration += metadata.num_frames / metadata.sample_rate
		total += 1

	# brain too fried to figure out a better way
	if args.autocalculate_voice_chunk_duration_size == 0:
		chunks = int(total_duration / total) if total > 0 else 1
	else:
		chunks = int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1

	# short audio could round down to 0, which callers treat as "use the training dataset" (see above);
	# no dataset exists on this path, so never return less than one chunk
	return max(1, chunks)

def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, original_ar=False, original_diffusion=False):
	"""
	Compute conditioning latents for a voice and save them to
	`{get_voice_dir()}/{voice}/cond_latents_{model hash[:8]}.pth`.

	Args:
		voice: voice name. When set, samples are (re)loaded: from ./training/{voice}/train.txt when
			`voice_latents_chunks` is 0 and that dataset exists, otherwise via load_voice().
		voice_samples: samples to use when `voice` is not given.
		voice_latents_chunks: passed as `slices` to tts.get_conditioning_latents(); 0 requests the dataset path.
		original_ar, original_diffusion: forwarded to tts.get_conditioning_latents().

	Returns:
		The conditioning latents (with a 4-tuple's last element replaced by None), or None when
		there are no samples or the backend is "bark".
	"""
	global tts
	global args

	unload_whisper()
	unload_voicefixer()

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		load_tts()

	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	if args.tts_backend == "bark":
		tts.create_voice( voice )
		return

	if args.autoregressive_model == "auto":
		tts.load_autoregressive_model(deduce_autoregressive_model(voice))

	if voice:
		load_from_dataset = voice_latents_chunks == 0

		if load_from_dataset:
			dataset_path = f'./training/{voice}/train.txt'
			if not os.path.exists(dataset_path):
				load_from_dataset = False
			else:
				with open(dataset_path, 'r', encoding="utf-8") as f:
					# blank lines would otherwise turn into a path to the dataset directory itself
					lines = [ line for line in f.readlines() if line.strip() ]

				print("Leveraging dataset for computing latents")

				voice_samples = []
				max_length = 0
				for line in lines:
					filename = f'./training/{voice}/{line.split("|")[0]}'

					waveform = load_audio(filename, 22050)
					max_length = max(max_length, waveform.shape[-1])
					voice_samples.append(waveform)

				# bring every clip to the longest clip's length
				voice_samples = [ pad_or_truncate(sample, max_length) for sample in voice_samples ]

				voice_latents_chunks = len(voice_samples)
				if voice_latents_chunks == 0:
					print("Dataset is empty!")
					# fall back to the voice's own samples (this previously stayed True and passed an empty list on)
					load_from_dataset = False
		if not load_from_dataset:
			voice_samples, _ = load_voice(voice, load_latents=False)

	if voice_samples is None:
		return

	conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, original_ar=original_ar, original_diffusion=original_diffusion)

	if len(conditioning_latents) == 4:
		conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

	outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
	torch.save(conditioning_latents, outfile)
	print(f'Saved voice latents: {outfile}')

	return conditioning_latents

# superfluous, but it cleans up some things
class TrainingState():
	"""
	Wraps one training run: spawns the trainer process, parses its output and log files into
	progress metrics and plot statistics, and prunes old checkpoints.
	"""
	def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
		"""
		Args:
			config_path: training YAML; its directory must also contain train.json and train.txt.
			keep_x_past_checkpoints: if > 0, prune all but this many saved checkpoints right away.
			start: spawn the training process immediately.
		"""
		self.killed = False
		# stays None unless spawn_process() runs, so `state.process` checks are always safe
		self.process = None

		self.training_dir = os.path.dirname(config_path)
		with open(config_path, 'r') as file:
			self.yaml_config = yaml.safe_load(file)

		with open(f"{self.training_dir}/train.json", 'r', encoding="utf-8") as file:
			self.json_config = json.load(file)
		self.dataset_path = f"{self.training_dir}/train.txt"
		with open(self.dataset_path, 'r', encoding="utf-8") as f:
			self.dataset_size = len(f.readlines())

		self.batch_size = self.json_config["batch_size"]
		self.save_rate = self.json_config["save_rate"]

		self.epoch = 0
		self.epochs = self.json_config["epochs"]
		self.it = 0
		self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size )
		self.step = 0
		self.steps = int(self.its / self.dataset_size)
		self.checkpoint = 0
		self.checkpoints = int((self.its - self.it) / self.save_rate)

		self.gpus = self.json_config['gpus']

		self.buffer = []

		self.open_state = False
		self.training_started = False

		self.info = {}

		self.it_rate = ""
		self.it_rates = 0

		self.epoch_rate = ""

		self.eta = "?"
		self.eta_hhmmss = "?"

		self.nan_detected = False

		self.last_info_check_at = 0
		self.statistics = {
			'loss': [],
			'lr': [],
			'grad_norm': [],
		}
		self.losses = []
		self.metrics = {
			'step': "",
			'rate': "",
			'loss': "",
		}

		self.loss_milestones = [ 1.0, 0.15, 0.05 ]

		if args.tts_backend=="vall-e":
			self.valle_last_it = 0
			self.valle_steps = 0

		if keep_x_past_checkpoints > 0:
			self.cleanup_old(keep=keep_x_past_checkpoints)
		if start:
			self.spawn_process(config_path=config_path, gpus=self.gpus)

	def spawn_process(self, config_path, gpus=1):
		"""Start the trainer subprocess with stdout and stderr merged into a text pipe."""
		if args.tts_backend == "vall-e":
			self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
		else:
			self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]

		print("Spawning process: ", " ".join(self.cmd))
		self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)

	def parse_metrics(self, data):
		"""
		Fold one metrics record into the running counters, rate/ETA strings, and plot statistics.

		`data` is a decoded metrics dict, or a raw line containing "Training Metrics:" /
		"Validation Metrics:" followed by JSON. Returns the decoded dict, or None when a string
		holds no metrics.
		"""
		if isinstance(data, str):
			# this branch used to read an undefined `line` and raise NameError
			if data.find('Training Metrics:') >= 0:
				data = json.loads(data.split("Training Metrics:")[-1])
				data['mode'] = "training"
			elif data.find('Validation Metrics:') >= 0:
				data = json.loads(data.split("Validation Metrics:")[-1])
				data['mode'] = "validation"
			else:
				return

		self.info = data
		if 'epoch' in self.info:
			self.epoch = int(self.info['epoch'])
		if 'it' in self.info:
			self.it = int(self.info['it'])
		if 'step' in self.info:
			self.step = int(self.info['step'])
		if 'steps' in self.info:
			self.steps = int(self.info['steps'])

		if 'elapsed_time' in self.info:
			self.info['iteration_rate'] = self.info['elapsed_time']
			del self.info['elapsed_time']

		if 'iteration_rate' in self.info:
			it_rate = self.info['iteration_rate']
			self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
			self.it_rates += it_rate

			if self.it_rates > 0 and self.it * self.steps > 0:
				epoch_rate = self.it_rates / self.it * self.steps
				self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'

			try:
				self.eta = (self.its - self.it) * (self.it_rates / self.it)
				eta = str(timedelta(seconds=int(self.eta)))
				self.eta_hhmmss = eta
			except Exception as e:
				self.eta_hhmmss = "?"

		self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
		if self.epochs != self.its:
			self.metrics['step'].append(f"{self.it}/{self.its}")
		if self.steps > 1:
			self.metrics['step'].append(f"{self.step}/{self.steps}")
		self.metrics['step'] = ", ".join(self.metrics['step'])

		if args.tts_backend == "tortoise":
			# steps is int(its / dataset_size) and can be 0, so don't divide by it
			epoch = self.epoch + (self.step / self.steps if self.steps > 0 else 0)
		else:
			epoch = self.info['epoch'] if 'epoch' in self.info else self.it

		if self.it > 0:
			# probably can double for-loop but whatever
			keys = {
				'lrs': ['lr'],
				'losses': ['loss_text_ce', 'loss_mel_ce'],
				'accuracies': [],
				'precisions': [],
				'grad_norms': [],
			}
			if args.tts_backend == "vall-e":
				keys['lrs'] = [
					'ar.lr', 'nar.lr',
				]
				keys['losses'] = [
				#	'ar.loss', 'nar.loss', 'ar+nar.loss',
					'ar.loss.nll', 'nar.loss.nll',
				]

				keys['accuracies'] = [
					'ar.loss.acc', 'nar.loss.acc',
					'ar.stats.acc', 'nar.stats.acc',
				]
				keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ]
				keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm']

			for k in keys['lrs']:
				if k not in self.info:
					continue

				self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

			for k in keys['accuracies']:
				if k not in self.info:
					continue

				self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

			for k in keys['precisions']:
				if k not in self.info:
					continue

				self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

			for k in keys['losses']:
				if k not in self.info:
					continue

				prefix = ""

				if "mode" in self.info and self.info["mode"] == "validation":
					prefix = f'{self.info["name"] if "name" in self.info else "val"}_'

				self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })

			# records without any loss key would otherwise hit an IndexError here
			if self.statistics['loss']:
				self.losses.append( self.statistics['loss'][-1] )

			for k in keys['grad_norms']:
				if k not in self.info:
					continue
				self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

		return data

	def get_status(self):
		"""Format the one-line status: `[step] [rate] [ETA: ...] [LR, Loss]`, prefixed when NaN was seen."""
		message = None

		self.metrics['rate'] = []
		if self.epoch_rate:
			self.metrics['rate'].append(self.epoch_rate)
		if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]:
			self.metrics['rate'].append(self.it_rate)
		self.metrics['rate'] = ", ".join(self.metrics['rate'])

		eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"

		self.metrics['loss'] = []
		if 'lr' in self.info:
			self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')

		if len(self.losses) > 0:
			self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')

		# disabled: experimental loss-milestone estimation
		if False and len(self.losses) >= 2:
			deriv = 0
			accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
			loss_value = self.losses[-1]["value"]

			for i in range(accum_length):
				d1_loss = self.losses[accum_length-i-1]["value"]
				d2_loss = self.losses[accum_length-i-2]["value"]
				dloss = (d2_loss - d1_loss)

				d1_step = self.losses[accum_length-i-1]["it"]
				d2_step = self.losses[accum_length-i-2]["it"]
				dstep = (d2_step - d1_step)

				if dstep == 0:
					continue

				inst_deriv = dloss / dstep
				deriv += inst_deriv

			deriv = deriv / accum_length

			print("Deriv: ", deriv)

			if deriv != 0: # dloss < 0:
				next_milestone = None
				for milestone in self.loss_milestones:
					if loss_value > milestone:
						next_milestone = milestone
						break

				print(f"Loss value: {loss_value} | Next milestone: {next_milestone} | Distance: {loss_value - next_milestone}")

				if next_milestone:
					# tfw can do simple calculus but not basic algebra in my head
					est_its = (next_milestone - loss_value) / deriv * 100
					print(f"Estimated: {est_its}")
					if est_its >= 0:
						self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
				else:
					est_loss = inst_deriv * (self.its - self.it) + loss_value
					if est_loss >= 0:
						self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')

		self.metrics['loss'] = ", ".join(self.metrics['loss'])

		message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
		if self.nan_detected:
			message = f"[!NaN DETECTED! {self.nan_detected}] {message}"

		return message

	def load_statistics(self, update=False):
		"""
		Rebuild plot statistics by replaying the metrics lines in the run's log files.

		With `update=True`, only the newest log is read, existing statistics are kept, and entries at or
		below the last processed iteration are skipped.
		"""
		if not os.path.isdir(self.training_dir):
			return

		if args.tts_backend == "tortoise":
			log_root = f'{self.training_dir}/finetune/'
			if not os.path.isdir(log_root):
				return
			logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(log_root) if d[-4:] == ".log" ])
		else:
			log_dir = "logs"
			log_root = f'{self.training_dir}/{log_dir}/'
			if not os.path.isdir(log_root):
				return
			logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(log_root) ])

		if update:
			if not logs:
				return
			logs = [logs[-1]]

		highest_step = self.last_info_check_at

		if not update:
			self.statistics['loss'] = []
			self.statistics['lr'] = []
			self.statistics['grad_norm'] = []
			self.it_rates = 0

		unq = {}
		averager = None
		prev_state = 0
		# last seen iteration/epoch; validation lines without their own inherit these
		it = None
		epoch = None

		for log in logs:
			with open(log, 'r', encoding="utf-8") as f:
				lines = f.readlines()

			for line in lines:
				line = line.strip()
				if not line:
					continue

				if line[-1] == ".":
					line = line[:-1]

				if line.find('Training Metrics:') >= 0:
					data = json.loads(line.split("Training Metrics:")[-1])

					name = "train"
					mode = "training"
					prev_state = 0
				elif line.find('Validation Metrics:') >= 0:
					data = json.loads(line.split("Validation Metrics:")[-1])
					# a validation line before any training line has nothing to inherit
					if "it" not in data and it is not None:
						data['it'] = it
					if "epoch" not in data and epoch is not None:
						data['epoch'] = epoch

					# name = data['name'] if 'name' in data else "val"
					mode = "validation"

					# the first validation block after a training line is "subtrain", later ones "val"
					if prev_state == 0:
						name = "subtrain"
					else:
						name = "val"

					prev_state += 1
				else:
					continue

				if "it" not in data:
					continue

				it = data['it']
				epoch = data.get('epoch', epoch)

				# incremental mode: skip what a previous call already consumed
				if update and it <= self.last_info_check_at:
					continue
				highest_step = max(highest_step, it)

				if args.tts_backend == "vall-e":
					if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
						averager = {
							'key': f'{it}_{name}',
							'name': name,
							'mode': mode,
							"metrics": {}
						}
						for k in data:
							if data[k] is None:
								continue
							averager['metrics'][k] = [ data[k] ]
					else:
						for k in data:
							if data[k] is None:
								continue
							if k not in averager['metrics']:
								averager['metrics'][k] = [ data[k] ]
							else:
								averager['metrics'][k].append( data[k] )

					unq[f'{it}_{mode}_{name}'] = averager
				else:
					unq[f'{it}_{mode}_{name}'] = data

		blacklist = [ "batch", "eval" ]
		for key in unq:
			if args.tts_backend == "vall-e":
				# average every metric reported for the same iteration/name
				stats = unq[key]
				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
				data['name'] = stats['name']
				data['mode'] = stats['mode']
				data['steps'] = len(stats['metrics']['it'])
			else:
				data = unq[key]
			self.parse_metrics(data)

		self.last_info_check_at = highest_step

	def cleanup_old(self, keep=2):
		"""Delete all but the newest `keep` GPT checkpoints and training states (TorToiSe runs only)."""
		if keep <= 0:
			return

		if args.tts_backend == "vall-e":
			return

		if not os.path.isdir(f'{self.training_dir}/finetune/'):
			return

		models_dir = f'{self.training_dir}/finetune/models/'
		states_dir = f'{self.training_dir}/finetune/training_state/'
		# int() below expects numeric step names; anything else is ignored instead of raising
		models = sorted([ int(d[:-8]) for d in os.listdir(models_dir) if d[-8:] == "_gpt.pth" and d[:-8].isdigit() ]) if os.path.isdir(models_dir) else []
		states = sorted([ int(d[:-6]) for d in os.listdir(states_dir) if d[-6:] == ".state" and d[:-6].isdigit() ]) if os.path.isdir(states_dir) else []
		remove_models = models[:-keep]
		remove_states = states[:-keep]

		for d in remove_models:
			path = f'{self.training_dir}/finetune/models/{d}_gpt.pth'
			print("Removing", path)
			os.remove(path)
		for d in remove_states:
			path = f'{self.training_dir}/finetune/training_state/{d}.state'
			print("Removing", path)
			os.remove(path)

	def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
		"""
		Consume one line of trainer output.

		Returns `(result, percent, message)`: `result` is the text to display (the recent output buffer
		before training has started, the status message afterwards) or None when nothing new should be shown.
		"""
		self.buffer.append(f'{line}')

		data = None
		percent = 0
		message = None
		should_return = False

		MESSAGE_START = 'Start training from epoch'
		MESSAGE_FINSIHED = 'Finished training'
		MESSAGE_SAVING = 'Saving models and training states.'

		MESSAGE_METRICS_TRAINING = 'Training Metrics:'
		MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'

		if line.find(MESSAGE_FINSIHED) >= 0:
			self.killed = True
		# rip out iteration info
		elif not self.training_started:
			if line.find(MESSAGE_START) >= 0:
				self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations

				match = re.findall(r'epoch: ([\d,]+)', line)
				if match and len(match) > 0:
					self.epoch = int(match[0].replace(",", ""))
				match = re.findall(r'iter: ([\d,]+)', line)
				if match and len(match) > 0:
					self.it = int(match[0].replace(",", ""))

				self.checkpoints = int((self.its - self.it) / self.save_rate)

				self.load_statistics()

				should_return = True
		else:
			if line.find(MESSAGE_SAVING) >= 0:
				self.checkpoint += 1
				message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
				# checkpoints is 0 when fewer than save_rate iterations remain
				percent = self.checkpoint / self.checkpoints if self.checkpoints > 0 else 0

				self.cleanup_old(keep=keep_x_past_checkpoints)
			elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
				data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1])
				data['mode'] = "training"
			elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
				data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1])
				data['mode'] = "validation"

		if data is not None:
			if ': nan' in line and not self.nan_detected:
				self.nan_detected = self.it

			self.parse_metrics( data )
			message = self.get_status()

			if message:
				percent = self.it / float(self.its) if self.its > 0 else 0 # self.epoch / float(self.epochs)
				if progress is not None:
					progress(percent, message)

				self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
				should_return = True

		if verbose and not self.training_started:
			should_return = True

		self.buffer = self.buffer[-buffer_size:]

		result = None
		if should_return:
			result = "".join(self.buffer) if not self.training_started else message

		return (
			result,
			percent,
			message,
		)

# Lift the row limit on Altair's default data transformer; this is best-effort, and any
# failure (e.g. altair being unavailable) is printed and otherwise ignored.
try:
	import altair as alt
	alt.data_transformers.enable('default', max_rows=None)
except Exception as e:
	print(e)

def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
	"""
	Start a training run for `config_path` and stream its progress.

	This is a generator: it yields the status text produced by TrainingState.parse() for each trainer
	output line. If a run is already active it yields a single message and stops. When the trainer's
	output ends, its stdout is closed, the process is reaped, and the global state is cleared.
	"""
	global training_state
	if training_state and getattr(training_state, 'process', None):
		# this function is a generator, so `return <value>` would never reach the caller; yield it instead
		yield "Training already in progress"
		return

	# ensure we have the dvae.pth
	if args.tts_backend == "tortoise":
		get_model_path('dvae.pth')

	# unclear whether this is still needed: training complained without it, even though it runs in a separate process
	torch.multiprocessing.freeze_support()

	unload_tts()
	unload_whisper()
	unload_voicefixer()

	training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)

	for line in iter(training_state.process.stdout.readline, ""):
		# NOTE(review): returning here skips the stdout close / wait / reset below; presumably whatever sets
		# `killed` externally performs that cleanup, but a "Finished training" line also sets it — confirm
		if training_state.killed:
			return

		result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
		print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
		if result:
			yield result

			if progress is not None and message:
				progress(percent, message)

	if training_state:
		training_state.process.stdout.close()
		training_state.process.wait()
		training_state = None

def _training_lineplot( series, title, **limits ):
	"""Build a gr.LinePlot update for one metric series (``limits`` = x_lim/y_lim), or None if it is empty."""
	if len(series) == 0:
		return None

	return gr.LinePlot.update(
		value = pd.DataFrame(series),
		x="epoch", y="value", # x="it",
		title=title, color="type", tooltip=['epoch', 'it', 'value', 'type'],
		width=500, height=350,
		**limits
	)

def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
	"""
	Build the loss / learning-rate / gradient-norm plot updates for the training tab.

	Uses the running training's statistics when there is one. Otherwise, if ``config_path`` is
	given, loads saved statistics into a temporary, non-started TrainingState. The global
	``training_state`` is never modified here, so viewing plots cannot detach a live run.

	:param x_min, x_max: x-axis limits; a falsy ``x_max`` defaults to the state's epoch count.
	:param y_min, y_max: y-axis limits for the loss plot only; a falsy ``y_max`` disables them.
	:param config_path: training config used to load statistics when no run is active.
	:returns: ``(losses, lrs, grad_norms)``; each is a gr.LinePlot update or None.
	"""
	losses = None
	lrs = None
	grad_norms = None

	x_lim = [ x_min, x_max ]
	y_lim = [ y_min, y_max ]

	state = training_state
	if not state and config_path:
		# throwaway state used only for reading the saved statistics
		state = TrainingState(config_path=config_path, start=False)
		state.load_statistics()

	if state:
		if not x_lim[-1]:
			x_lim[-1] = state.epochs

		if not y_lim[-1]:
			y_lim = None

		losses = _training_lineplot(state.statistics['loss'], "Loss Metrics", x_lim=x_lim, y_lim=y_lim)
		lrs = _training_lineplot(state.statistics['lr'], "Learning Rate", x_lim=x_lim)
		grad_norms = _training_lineplot(state.statistics['grad_norm'], "Gradient Normals", x_lim=x_lim)

	return (losses, lrs, grad_norms)

def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
	"""
	Re-attach to an already running training process and resume streaming its output.

	Generator with the same streaming behaviour as run_training(), without starting anything.

	:param verbose: forwarded to ``TrainingState.parse``.
	:param progress: gradio progress tracker, updated alongside each yielded result.
	"""
	# local handle: stop_training() may reset the global to None while we are streaming
	state = training_state
	if not state or not state.process:
		# a `return value` inside a generator is discarded by the consumer; yield the message instead
		yield "Training not in progress"
		return

	for line in iter(state.process.stdout.readline, ""):
		if state.killed:
			return

		result, percent, message = state.parse( line=line, verbose=verbose, progress=progress )
		print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
		if result:
			yield result

			if progress is not None and message:
				progress(percent, message)

def stop_training():
	"""
	Kill the running training process.

	For tortoise, processes whose command line contains ``./src/train.py`` are also killed,
	since they may outlive the launched process.

	:returns: a status string for the UI, including the process' exit code.
	"""
	global training_state
	if training_state is None:
		return "No training in progress"
	print("Killing training process...")
	training_state.killed = True

	children = []
	if args.tts_backend == "tortoise":
		# best-effort scan; psutil may fail on some platforms
		try:
			children = [
				p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline'])
				# cmdline is None for processes psutil cannot inspect; don't let one abort the whole scan
				if './src/train.py' in (p.info['cmdline'] or [])
			]
		except Exception as e:
			print(f"Failed to enumerate training processes: {e}")

		training_state.process.stdout.close()
		training_state.process.terminate()
		training_state.process.kill()
	elif args.tts_backend == "vall-e":
		print(training_state.process.communicate(input='quit')[0])

	return_code = training_state.process.wait()

	# SIGKILL does not exist on Windows; fall back to SIGTERM there
	kill_signal = getattr(signal, 'SIGKILL', signal.SIGTERM)
	for p in children:
		try:
			os.kill( p['pid'], kill_signal )
		except OSError:
			# already gone (e.g. it was the process reaped above) or not ours to kill
			pass

	training_state = None
	print("Killed training process.")
	return f"Training cancelled: {return_code}"

def get_halfp_model_path():
	"""
	Return the path of the half-precision copy of the autoregressive model.

	The copy sits next to the path returned by ``get_model_path('autoregressive.pth')``, with
	``_half`` inserted before the file extension. Only the extension is affected, so directory
	names containing ".pth" are left alone, and the result never equals the source path.
	"""
	root, ext = os.path.splitext(get_model_path('autoregressive.pth'))
	return f"{root}_half{ext}"

def convert_to_halfp(source_path=None, output_path=None):
	"""
	Save a float16 copy of the autoregressive checkpoint.

	Only floating-point tensors are cast; integer buffers and non-tensor entries are kept as-is.
	The checkpoint is loaded onto the CPU so no GPU memory is needed for the conversion.

	:param source_path: checkpoint to convert; defaults to ``get_model_path('autoregressive.pth')``.
	:param output_path: destination; defaults to ``get_halfp_model_path()``.
	:returns: the path the converted checkpoint was written to.
	"""
	autoregressive_model_path = source_path if source_path else get_model_path('autoregressive.pth')
	print(f'Converting model to half precision: {autoregressive_model_path}')
	# assumes a flat name -> tensor mapping (state dict)
	model = torch.load(autoregressive_model_path, map_location='cpu')
	for k, v in model.items():
		if torch.is_tensor(v) and v.is_floating_point():
			model[k] = v.half()

	outfile = output_path if output_path else get_halfp_model_path()
	torch.save(model, outfile)
	print(f'Converted model to half precision: {outfile}')
	return outfile


# collapses short segments into the previous segment
def whisper_sanitize( results, min_duration=None ):
	"""
	Return a copy of a whisper result with too-short segments merged into their predecessor.

	:param results: whisper-style dict with a ``segments`` list of dicts carrying ``start``,
	  ``end`` (seconds) and ``text``. It is not modified.
	:param min_duration: segments shorter than this are merged; defaults to MIN_TRAINING_DURATION.
	:returns: deep copy of ``results`` with merged segments and ``id``s renumbered from 0.
	"""
	if min_duration is None:
		min_duration = MIN_TRAINING_DURATION

	# JSON round-trip = deep copy; all edits below touch only the copy's segment dicts
	sanitized = json.loads(json.dumps(results))
	segments = sanitized['segments']
	sanitized['segments'] = []

	for segment in segments:
		length = segment['end'] - segment['start']
		if length >= min_duration or not sanitized['segments']:
			sanitized['segments'].append(segment)
			continue

		last_segment = sanitized['segments'][-1]
		# the previous segment already spans this one
		if last_segment['end'] >= segment['end']:
			continue

		last_segment['text'] += segment['text']
		last_segment['end'] = segment['end']

	for i, segment in enumerate(sanitized['segments']):
		segment['id'] = i

	return sanitized

def whisper_transcribe( file, language=None ):
	"""
	Transcribe one audio file with the backend selected by ``args.whisper_backend``.

	Loads the whisper model first if it is not loaded yet.

	:param file: path to the audio file.
	:param language: optional language hint; falsy values are passed to openai/whisper as None.
	:returns: dict with ``text`` (full transcription) and ``segments`` (dicts with ``start``,
	  ``end``, ``text`` and ``id``). The whispercpp result also carries ``language`` (the
	  requested hint, possibly None), because prepare_dataset() reads that key.
	:raises ValueError: if ``args.whisper_backend`` is not a supported backend.
	"""
	global whisper_model
	global whisper_align_model

	# shouldn't happen, but it's for safety
	if not whisper_model:
		load_whisper_model(language=language)

	backend = args.whisper_backend

	if backend == "openai/whisper":
		return whisper_model.transcribe(file, language=language if language else None)

	if backend == "lightmare/whispercpp":
		res = whisper_model.transcribe(file)
		segments = whisper_model.extract_text_and_timestamps( res )

		texts = []
		result = {
			'text': texts,
			'segments': [],
			'language': language,
		}
		for segment in segments:
			# timestamps are divided by 100 to get seconds
			result['segments'].append({
				'start': segment[0] / 100.0,
				'end': segment[1] / 100.0,
				'text': segment[2],
				'id': len(result['segments'])
			})
			texts.append( segment[2] )

		result['text'] = " ".join(texts)
		return result

	if backend == "m-bain/whisperx":
		import whisperx

		device = "cuda" if get_device_name() == "cuda" else "cpu"
		result = whisper_model.transcribe(file, batch_size=args.whisper_batchsize)

		align_model, metadata = whisper_align_model
		result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device, return_char_alignments=False)

		result['segments'] = result_aligned['segments']
		texts = []
		for segment in result['segments']:
			segment['id'] = len(texts)
			texts.append(segment['text'].strip())
		result['text'] = " ".join(texts)

		return result

	raise ValueError(f"Unsupported whisper backend: {backend}")

def validate_waveform( waveform, sample_rate, min_only=False, min_duration=None, max_duration=None ):
	"""
	Check whether an audio clip is usable as a training sample.

	:param waveform: audio tensor whose last dimension is time (e.g. ``(channels, frames)``).
	:param sample_rate: samples per second of ``waveform``.
	:param min_only: only enforce the minimum duration.
	:param min_duration: minimum length in seconds; defaults to MIN_TRAINING_DURATION.
	:param max_duration: maximum length in seconds; defaults to MAX_TRAINING_DURATION.
	:returns: a human-readable error string, or None when the clip passes.
	"""
	if min_duration is None:
		min_duration = MIN_TRAINING_DURATION
	if max_duration is None:
		max_duration = MAX_TRAINING_DURATION

	# zero-length or all-silent audio has nothing to train on; test for any non-zero sample
	# (not just negative ones) so clips that never dip below zero are not rejected
	if not torch.any(waveform != 0):
		return "Waveform is empty"

	duration = waveform.shape[-1] / sample_rate

	if duration < min_duration:
		return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, min_duration)

	if not min_only and duration > max_duration:
		return "Duration too long ({:.3f}s < {:.3f}s)".format(max_duration, duration)

	return None

def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
	"""
	Transcribe every audio file of a voice into ``./training/{voice}/whisper.json``.

	Results are saved after each file. Unless the vall-e backend is active, each source file is
	also resampled, reduced to its first channel if stereo, and written to
	``./training/{voice}/audio/``. Finally, short segments are merged with whisper_sanitize().
	If anything changed, the previous JSON is kept as ``whisper.unsanitized.json``.

	:param voice: voice name as understood by get_voice().
	:param language: optional whisper language hint.
	:param skip_existings: skip files that already have an entry in whisper.json.
	:param progress: unused.
	:returns: a status message.
	"""
	global whisper_model

	files = get_voice(voice, load_latents=False)
	if files is None:
		message = f"Voice not found: {voice}"
		print(message)
		return message

	unload_tts()

	if whisper_model is None:
		load_whisper_model(language=language)

	results = {}

	indir = f'./training/{voice}/'
	infile = f'{indir}/whisper.json'

	quantize_in_memory = args.tts_backend == "vall-e"

	os.makedirs(f'{indir}/audio/', exist_ok=True)

	TARGET_SAMPLE_RATE = 22050
	if args.tts_backend != "tortoise":
		TARGET_SAMPLE_RATE = 24000
	if tts:
		TARGET_SAMPLE_RATE = tts.input_sample_rate

	if os.path.exists(infile):
		with open(infile, 'r', encoding="utf-8") as f:
			results = json.load(f)

	for file in tqdm(files, desc="Iterating through voice files"):
		basename = os.path.basename(file)

		if basename in results and skip_existings:
			print(f"Skipping already parsed file: {basename}")
			continue

		try:
			result = whisper_transcribe(file, language=language)
		except Exception as e:
			print("Failed to transcribe:", file, e)
			continue

		results[basename] = result

		if not quantize_in_memory:
			waveform, sample_rate = torchaudio.load(file)
			# resample to the input rate, since it'll get resampled for training anyways
			# this should also "help" increase throughput a bit when filling the dataloaders
			waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
			if waveform.shape[0] == 2:
				waveform = waveform[:1]

			try:
				kwargs = {}
				if basename[-4:] == ".wav":
					kwargs['encoding'] = "PCM_S"
					kwargs['bits_per_sample'] = 16

				torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs)
			except Exception as e:
				print(e)

		# checkpoint after every file so an interrupted run keeps its progress
		with open(infile, 'w', encoding="utf-8") as f:
			f.write(json.dumps(results, indent='\t'))

		do_gc()

	modified = False
	for basename in results:
		try:
			sanitized = whisper_sanitize(results[basename])
			if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
				results[basename] = sanitized
				modified = True
				print("Segments sanizited: ", basename)
		except Exception as e:
			print("Failed to sanitize:", basename, e)

	if modified:
		# os.replace overwrites an existing backup on every platform (os.rename fails on Windows)
		os.replace(infile, infile.replace(".json", ".unsanitized.json"))
		with open(infile, 'w', encoding="utf-8") as f:
			f.write(json.dumps(results, indent='\t'))

	return f"Processed dataset to: {indir}"

def slice_waveform( waveform, sample_rate, start, end, trim ):
	"""
	Cut the ``[start, end)`` window (in seconds) out of a waveform.

	:param waveform: audio tensor indexed as ``waveform[:, frames]``.
	:param sample_rate: samples per second of ``waveform``.
	:param start, end: window bounds in seconds; clamped to the waveform.
	:param trim: run torchaudio's VAD trimmer on the slice if it passed validation.
	:returns: ``(sliced, error)``; ``error`` is validate_waveform()'s message (minimum-duration
	  check only) or None.
	"""
	num_frames = waveform.shape[-1]
	# the slice end is exclusive, so it may equal num_frames; clamping to num_frames - 1 would drop the last frame
	start = max(0, int(start * sample_rate))
	end = min(int(end * sample_rate), num_frames)

	sliced = waveform[:, start:end]

	error = validate_waveform( sliced, sample_rate, min_only=True )
	if trim and not error:
		sliced = torchaudio.functional.vad( sliced, sample_rate )

	return sliced, error

def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None, progress=gr.Progress() ):
	"""
	Cut every transcribed file of a voice into per-segment clips under ``./training/{voice}/audio/``.

	Source audio is looked up in ``./voices/{voice}/`` first, then ``./training/{voice}/``.
	Stereo sources are reduced to their first channel before slicing, so every clip is mono.
	Clips are resampled to the training sample rate and named ``<name>_<segment id>.<ext>``.

	:param voice: voice / dataset folder name.
	:param trim_silence: run VAD trimming on each clip (see slice_waveform()).
	:param start_offset, end_offset: seconds added to every segment's start / end.
	:param results: whisper results keyed by filename; loaded from whisper.json when None.
	:param progress: unused.
	:returns: newline-joined status messages ending with a summary.
	"""
	indir = f'./training/{voice}/'
	infile = f'{indir}/whisper.json'
	messages = []

	if not os.path.exists(infile):
		message = f"Missing dataset: {infile}"
		print(message)
		return message

	if results is None:
		with open(infile, 'r', encoding="utf-8") as f:
			results = json.load(f)

	TARGET_SAMPLE_RATE = 22050
	if args.tts_backend != "tortoise":
		TARGET_SAMPLE_RATE = 24000
	if tts:
		TARGET_SAMPLE_RATE = tts.input_sample_rate

	os.makedirs(f'{indir}/audio/', exist_ok=True)

	files = 0
	segments = 0
	for filename in results:
		path = f'./voices/{voice}/{filename}'
		extension = os.path.splitext(filename)[-1][1:]
		out_extension = extension # "wav"

		if not os.path.exists(path):
			path = f'./training/{voice}/{filename}'

		if not os.path.exists(path):
			message = f"Missing source audio: {filename}"
			print(message)
			messages.append(message)
			continue

		files += 1
		result = results[filename]
		waveform, sample_rate = torchaudio.load(path)
		# reduce stereo to the first channel *before* slicing so the first clip is not saved in stereo
		if waveform.shape[0] == 2:
			waveform = waveform[:1]

		for segment in result['segments']:
			file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")

			sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
			if error:
				message = f"{error}, skipping... {file}"
				print(message)
				messages.append(message)
				continue

			sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )

			# WAV clips are written as 16-bit signed PCM
			kwargs = {}
			if file[-4:] == ".wav":
				kwargs['encoding'] = "PCM_S"
				kwargs['bits_per_sample'] = 16

			torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs)

			segments += 1

	messages.append(f"Sliced segments: {files} => {segments}.")
	return "\n".join(messages)

# takes an LJSpeech-dataset-formatted .txt file and phonemizes it
def phonemize_txt_file( path ):
	"""
	Phonemize an LJSpeech-style metadata ``.txt`` into a sibling ``.phn.txt``.

	Input lines are ``audio|text`` or ``audio|text|normalized text``; the normalized column is
	used when present. Blank or single-column lines are skipped. Output lines are
	``audio|phonemes``.

	:param path: path of the ``.txt`` metadata file.
	:returns: the final ``.phn.txt`` contents.
	"""
	with open(path, 'r', encoding='utf-8') as f:
		lines = f.readlines()

	out_path = path.replace(".txt", ".phn.txt")
	reparsed = []
	# entries are appended as they are produced (presumably so partial progress is on disk if
	# phonemizing aborts); the file is rewritten cleanly below once every line is done
	with open(out_path, 'a', encoding='utf-8') as f:
		for line in tqdm(lines, desc='Phonemizing...'):
			split = line.strip().split("|")
			if len(split) < 2:
				continue
			audio = split[0]
			# strip the trailing newline/whitespace so it is not fed to the phonemizer
			text = (split[2] if len(split) > 2 else split[1]).strip()

			phonemes = phonemizer( text )
			reparsed.append(f'{audio}|{phonemes}')
			f.write(f'\n{audio}|{phonemes}')

	joined = "\n".join(reparsed)
	with open(out_path, 'w', encoding='utf-8') as f:
		f.write(joined)

	return joined

# takes an LJSpeech-dataset-formatted .txt (and phonemized .phn.txt from the above) and creates a JSON that should slot in as whisper.json
def create_dataset_json( path ):
	"""
	Build a ``{audio: {'text': ..., 'phonemes': ...}}`` mapping from LJSpeech-style metadata.

	Text comes from the second column of ``path``. Phonemes come from the optional sibling
	``.phn.txt`` (``audio|phonemes``) and only for audio entries present in ``path``. Blank or
	malformed lines are skipped. The mapping is written next to ``path`` with ``.txt``
	replaced by ``.json``.

	:param path: path of the ``.txt`` metadata file.
	:returns: the mapping that was written.
	"""
	with open(path, 'r', encoding='utf-8') as f:
		lines = f.readlines()

	# no .phn.txt means there are simply no phonemes to merge in
	phonemes = []
	phn_path = path.replace(".txt", ".phn.txt")
	if os.path.exists(phn_path):
		with open(phn_path, 'r', encoding='utf-8') as f:
			phonemes = f.readlines()

	data = {}

	for line in lines:
		split = line.split("|")
		if len(split) < 2:
			continue
		data[split[0]] = {
			'text': split[1].strip()
		}

	for line in phonemes:
		split = line.split("|")
		if len(split) < 2 or split[0] not in data:
			continue
		data[split[0]]['phonemes'] = split[1].strip()

	with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
		f.write(json.dumps(data, indent="\t"))

	return data


# phonemizer backend instances, keyed by f'{language}_{backend}', reused across phonemizer() calls
cached_backends = {}

def phonemizer( text, language="en-us" ):
	"""
	Convert ``text`` to phonemes with the backend named by ``args.phonemizer_backend``.

	:param text: a single string to phonemize.
	:param language: phonemizer language code; ``"en"`` is mapped to ``"en-us"``.
	:returns: the phoneme string for ``text`` (the raw list only if the backend does not
	  return exactly one result).
	"""
	from phonemizer import phonemize
	from phonemizer.backend import BACKENDS

	def _get_backend( language="en-us", backend="espeak" ):
		key = f'{language}_{backend}'
		if key in cached_backends:
			return cached_backends[key]

		if backend == 'espeak':
			instance = BACKENDS[backend]( language, preserve_punctuation=True, with_stress=True)
		elif backend == 'espeak-mbrola':
			instance = BACKENDS[backend]( language )
		else:
			instance = BACKENDS[backend]( language, preserve_punctuation=True )

		cached_backends[key] = instance
		return instance

	if language == "en":
		language = "en-us"

	backend = _get_backend(language=language, backend=args.phonemizer_backend)
	if backend is not None:
		tokens = backend.phonemize( [text], strip=True )
	else:
		tokens = phonemize( [text], language=language, strip=True, preserve_punctuation=True, with_stress=True )

	# one input string yields a one-element list: hand back the string itself
	return tokens[0] if len(tokens) == 1 else tokens

def should_phonemize():
	"""
	Decide whether dataset text should be converted to phonemes.

	True only when ``args.tokenizer_json`` names a file ending in ``ipa.json`` and the
	``phonemizer`` package can be imported.
	"""
	tokenizer = args.tokenizer_json
	if tokenizer is None or not tokenizer.endswith("ipa.json"):
		return False

	try:
		from phonemizer import phonemize  # noqa: F401 -- availability probe only
	except Exception:
		return False

	return True

def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ):
	"""
	Turn ``./training/{voice}/whisper.json`` into ``train.txt`` / ``validation.txt``.

	Each transcribed file contributes either one line for the whole clip or one line per whisper
	segment (``audio/<file>|<text or phonemes>``). A file is segmented automatically when its
	text exceeds MAX_TRAINING_CHAR_LENGTH or its last segment ends at or after
	MAX_TRAINING_DURATION; missing segment clips are sliced on demand. Lines shorter than
	``text_length`` characters go to validation instead of training. For the vall-e backend,
	``.phn.txt`` and ``.qnt.pt`` files are also produced under ``./training/valle/data/{voice}/``.

	:param voice: dataset folder name under ``./training/``.
	:param use_segments: always emit per-segment lines.
	:param text_length: validation cut-off in characters (forced to 0 for non-tortoise backends).
	:param audio_length: currently not applied (forced to 0 for non-tortoise backends).
	:param progress: unused.
	:returns: newline-joined status messages followed by a summary and the dataset contents.
	"""
	indir = f'./training/{voice}/'
	infile = f'{indir}/whisper.json'
	if not os.path.exists(infile):
		message = f"Missing dataset: {infile}"
		print(message)
		return message

	with open(infile, 'r', encoding="utf-8") as f:
		results = json.load(f)

	errored = 0
	messages = []
	normalize = False # True
	phonemize = should_phonemize()
	lines = { 'training': [], 'validation': [] }
	segments = {}

	# vall-e loads and slices audio in memory instead of relying on pre-sliced clips
	quantize_in_memory = args.tts_backend == "vall-e"

	if args.tts_backend != "tortoise":
		text_length = 0
		audio_length = 0

	# padding (seconds) applied around each segment when slicing in memory
	start_offset = -0.1
	end_offset = 0.1
	trim_silence = False

	TARGET_SAMPLE_RATE = 22050
	if args.tts_backend != "tortoise":
		TARGET_SAMPLE_RATE = 24000
	if tts:
		TARGET_SAMPLE_RATE = tts.input_sample_rate

	for filename in tqdm(results, desc="Parsing results"):
		use_segment = use_segments

		extension = os.path.splitext(filename)[-1][1:]
		out_extension = extension # "wav"
		result = results[filename]
		lang = result['language']
		language = LANGUAGES[lang] if lang in LANGUAGES else lang
		normalizer = EnglishTextNormalizer() if language and language == "english" else BasicTextNormalizer()

		# fall back to segments when the whole-file text is too long
		if not use_segment:
			if len(result['text']) > MAX_TRAINING_CHAR_LENGTH:
				message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(result['text'])}), using segments: {filename}"
				print(message)
				messages.append(message)
				use_segment = True

		# fall back to segments when the whole-file audio is too long
		if not use_segment:
			path = f'{indir}/audio/{filename}'
			if not quantize_in_memory and not os.path.exists(path):
				messages.append(f"Missing source audio: {filename}")
				errored += 1
				continue

			# segments is a list of dicts; the clip length is taken as the latest segment end
			duration = 0
			for segment in result['segments']:
				duration = max(duration, segment['end'])

			if duration >= MAX_TRAINING_DURATION:
				message = f"Audio too large, using segments: {filename}"
				print(message)
				messages.append(message)
				use_segment = True

		# implicitly segmented: make sure the segment clips exist, slicing them if not
		if use_segment and not use_segments:
			exists = True
			for segment in result['segments']:
				duration = segment['end'] - segment['start']
				if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
					continue

				path = f'{indir}/audio/' + filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
				if os.path.exists(path):
					continue
				exists = False
				break

			if not quantize_in_memory and not exists:
				tmp = {}
				tmp[filename] = result
				print(f"Audio not segmented, segmenting: {filename}")
				message = slice_dataset( voice, results=tmp )
				print(message)
				messages = messages + message.split("\n")

		# (waveform, sample_rate) at TARGET_SAMPLE_RATE when slicing in memory, else None
		waveform = None

		if quantize_in_memory:
			path = f'{indir}/audio/{filename}'
			if not os.path.exists(path):
				path = f'./voices/{voice}/{filename}'

			if not os.path.exists(path):
				# reported only: the file's text lines are still emitted below
				message = f"Audio not found: {path}"
				print(message)
				messages.append(message)
			else:
				waveform = torchaudio.load(path)
				waveform = resample(waveform[0], waveform[1], TARGET_SAMPLE_RATE)

		if not use_segment:
			segments[filename] = {
				'text': result['text'],
				'lang': lang,
				'language': language,
				'normalizer': normalizer,
				'phonemes': result['phonemes'] if 'phonemes' in result else None
			}

			if waveform is not None:
				segments[filename]['waveform'] = waveform
		else:
			for segment in result['segments']:
				duration = segment['end'] - segment['start']
				if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
					continue

				file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")

				segments[file] = {
					'text': segment['text'],
					'lang': lang,
					'language': language,
					'normalizer': normalizer,
					'phonemes': segment['phonemes'] if 'phonemes' in segment else None
				}

				if waveform is not None:
					sliced, error = slice_waveform( waveform[0], waveform[1], segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
					if error:
						# recorded on the entry; its text line is still emitted
						message = f"{error}, skipping... {file}"
						print(message)
						messages.append(message)
						segments[file]['error'] = error
					else:
						segments[file]['waveform'] = (sliced, waveform[1])

	# deferred vall-e work: [output paths], [inputs]
	jobs = {
		'quantize':  [[], []],
		'phonemize': [[], []],
	}

	for file in tqdm(segments, desc="Parsing segments"):
		extension = os.path.splitext(file)[-1][1:]
		result = segments[file]
		path = f'{indir}/audio/{file}'

		text = result['text']
		lang = result['lang']
		language = result['language']
		normalizer = result['normalizer']
		phonemes = result['phonemes']
		if phonemize and phonemes is None:
			phonemes = phonemizer( text, language=lang )

		normalized = normalizer(text) if normalize else text

		if len(text) > MAX_TRAINING_CHAR_LENGTH:
			message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(text)}), skipping... {file}"
			print(message)
			messages.append(message)
			errored += 1
			continue

		# short lines are held out for validation (audio_length is not applied)
		culled = len(text) < text_length

		line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'

		lines['training' if not culled else 'validation'].append(line)

		if culled or args.tts_backend != "vall-e":
			continue

		os.makedirs(f'{indir}/valle/', exist_ok=True)
		# the .phn.txt / .qnt.pt outputs below are written here, so it must exist
		valle_dir = f'./training/valle/data/{voice}/'
		os.makedirs(valle_dir, exist_ok=True)

		phn_file = f'{valle_dir}{file.replace(f".{extension}",".phn.txt")}'
		if not os.path.exists(phn_file):
			jobs['phonemize'][0].append(phn_file)
			jobs['phonemize'][1].append(normalized)

		qnt_file = f'{valle_dir}{file.replace(f".{extension}",".qnt.pt")}'
		if 'error' not in result:
			if not quantize_in_memory and not os.path.exists(path):
				message = f"Missing segment, skipping... {file}"
				print(message)
				messages.append(message)
				errored += 1
				continue

		if not os.path.exists(qnt_file):
			waveform = None
			if 'waveform' in result:
				waveform, sample_rate = result['waveform']
			elif os.path.exists(path):
				waveform, sample_rate = torchaudio.load(path)
				error = validate_waveform( waveform, sample_rate )
				if error:
					message = f"{error}, skipping... {file}"
					print(message)
					messages.append(message)
					errored += 1
					continue

			if waveform is not None:
				jobs['quantize'][0].append(qnt_file)
				jobs['quantize'][1].append((waveform, sample_rate))

	for qnt_file, (waveform, sample_rate) in tqdm(zip(*jobs['quantize']), total=len(jobs['quantize'][0]), desc="Quantizing"):
		quantized = valle_quantize( waveform, sample_rate ).cpu()
		torch.save(quantized, qnt_file)

	for phn_file, normalized in tqdm(zip(*jobs['phonemize']), total=len(jobs['phonemize'][0]), desc="Phonemizing"):
		try:
			phonemized = valle_phonemize( normalized )
			with open(phn_file, 'w', encoding='utf-8') as f:
				f.write(" ".join(phonemized))
		except Exception as e:
			message = f"Failed to phonemize: {phn_file}: {normalized}"
			messages.append(message)
			print(message)

	training_joined = "\n".join(lines['training'])
	validation_joined = "\n".join(lines['validation'])

	with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
		f.write(training_joined)

	with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
		f.write(validation_joined)

	messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}")
	return "\n".join(messages)

def calc_iterations( epochs, lines, batch_size ):
	"""
	Return the total number of training steps.

	One epoch is ``ceil(lines / batch_size)`` steps; the product with ``epochs`` (which may be
	fractional) is rounded up to a whole step.
	"""
	steps_per_epoch = math.ceil(lines / batch_size)
	total = math.ceil(steps_per_epoch * epochs)
	return int(total)

def schedule_learning_rate( iterations, schedule=None ):
	"""
	Scale learning-rate milestones (given in epochs) to iteration counts.

	:param iterations: iterations per epoch.
	:param schedule: milestone epochs; defaults to LEARNING_RATE_SCHEDULE. A None sentinel is
	  used instead of sharing the module-level list as a mutable default.
	:returns: list of ``int(iterations * milestone)``.
	"""
	if schedule is None:
		schedule = LEARNING_RATE_SCHEDULE
	return [int(iterations * milestone) for milestone in schedule]

def optimize_training_settings( **kwargs ):
	"""
	Sanity-check and clamp training settings against the dataset and available hardware.

	Reads the line count of ``./training/{voice}/train.txt`` and adjusts batch size, gradient
	accumulation size, GPU count, save/validation rates (in epochs), resume state and the
	precision flags. When half precision is kept, the half model is converted if it is missing.

	:param kwargs: settings including voice, batch_size, gradient_accumulation_size, gpus,
	  epochs, save_rate, validation_rate, resume_state, bitsandbytes and half_p.
	:returns: ``(settings, messages)``: the adjusted settings and a note for every adjustment.
	:raises Exception: if the dataset is empty.
	"""
	messages = []
	settings = {}
	settings.update(kwargs)

	dataset_path = f"./training/{settings['voice']}/train.txt"
	with open(dataset_path, 'r', encoding="utf-8") as f:
		lines = len(f.readlines())

	if lines == 0:
		raise Exception("Empty dataset.")

	if settings['batch_size'] > lines:
		settings['batch_size'] = lines
		messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}")

	if settings['gradient_accumulation_size'] == 0:
		settings['gradient_accumulation_size'] = 1

	# keep at least two samples per accumulation step, and an accumulation size that divides the batch
	if settings['batch_size'] / settings['gradient_accumulation_size'] < 2:
		settings['gradient_accumulation_size'] = int(settings['batch_size'] / 2)
		if settings['gradient_accumulation_size'] == 0:
			settings['gradient_accumulation_size'] = 1

		messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}")
	elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0:
		settings['gradient_accumulation_size'] -= settings['batch_size'] % settings['gradient_accumulation_size']
		if settings['gradient_accumulation_size'] == 0:
			settings['gradient_accumulation_size'] = 1

		messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

	# clamp the GPU count to what exists (and to >= 1) before it is used as a divisor below
	if settings['gpus'] > get_device_count():
		settings['gpus'] = get_device_count()
		messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}")

	if settings['gpus'] <= 1:
		settings['gpus'] = 1
	else:
		messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.")

	if settings['batch_size'] % settings['gpus'] != 0:
		settings['batch_size'] -= settings['batch_size'] % settings['gpus']
		if settings['batch_size'] == 0:
			settings['batch_size'] = 1
		messages.append(f"Batch size not neatly divisible by GPU count, adjusting batch size to: {settings['batch_size']}")

	def _batch_cap_for_vram( vram ):
		"""Suggested maximum batch_size / gradient_accumulation_size ratio for ``vram`` GB."""
		DEVICE_BATCH_SIZE_MAP = [
			(70, 128), # based on an A100-80G, I can safely get a ratio of 4096:32 = 128
			(32, 64), # based on my two 6800XTs, I can only really safely get a ratio of 128:2 = 64
			(16, 8), # based on an A4000, I can do a ratio of 512:64 = 8:1
			(8, 4), # interpolated
			(6, 2), # based on my 2060, it only really lets me have a batch ratio of 2:1
		]
		for k, v in DEVICE_BATCH_SIZE_MAP:
			if vram > (k-1):
				return v
		return 1

	# assuming you have equal GPUs
	vram = get_device_vram() * settings['gpus']
	batch_ratio = int(settings['batch_size'] / settings['gradient_accumulation_size'])
	batch_cap = _batch_cap_for_vram(vram)

	if batch_ratio > batch_cap:
		settings['gradient_accumulation_size'] = int(settings['batch_size'] / batch_cap)
		messages.append(f"Batch ratio ({batch_ratio}) is expected to exceed your VRAM capacity ({'{:.3f}'.format(vram)}GB, suggested {batch_cap} batch size cap), adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

	iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])

	if settings['epochs'] < settings['save_rate']:
		settings['save_rate'] = settings['epochs']
		messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")

	if settings['epochs'] < settings['validation_rate']:
		settings['validation_rate'] = settings['epochs']
		messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {settings['validation_rate']}")

	if settings['resume_state'] and not os.path.exists(settings['resume_state']):
		settings['resume_state'] = None
		messages.append("Resume path specified, but does not exist. Disabling...")

	if settings['bitsandbytes']:
		messages.append("! EXPERIMENTAL ! BitsAndBytes requested.")

	if settings['half_p']:
		if settings['bitsandbytes']:
			settings['half_p'] = False
			messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
		else:
			messages.append("! EXPERIMENTAL ! Half Precision requested.")
			if not os.path.exists(get_halfp_model_path()):
				convert_to_halfp()

	# guard against epochs == 0 (no steps at all)
	steps = int(iterations / settings['epochs']) if settings['epochs'] else 0

	messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({steps} steps per epoch)")

	return settings, messages

def save_training_settings( **kwargs ):
	"""
	Persist training settings and render the backend's training config from its template.

	Writes the raw settings to ``./training/{voice}/train.json``, then derives iteration counts,
	save/validation rates (converted from epochs to iterations), validation batch size, optimizer
	and learning-rate schedule. Every non-None setting is substituted into the template's
	``${key}`` placeholders: ``./models/.template.dlas.yaml`` becomes ``train.yaml`` (tortoise),
	``./models/.template.valle.yaml`` becomes ``config.yaml`` (vall-e).

	:param kwargs: training settings (voice, epochs, batch_size, gradient_accumulation_size,
	  save_rate, validation_rate, half_p, source_model, resume_state, gpus,
	  learning_rate_scheme / _schedule / _restarts, ...).
	:returns: ``(settings, messages)``: the derived settings and status notes.
	"""
	messages = []
	settings = {}
	settings.update(kwargs)

	outjson = f'./training/{settings["voice"]}/train.json'
	with open(outjson, 'w', encoding="utf-8") as f:
		f.write(json.dumps(settings, indent='\t') )

	settings['dataset_path'] = f"./training/{settings['voice']}/train.txt"
	settings['validation_path'] = f"./training/{settings['voice']}/validation.txt"

	with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
		lines = len(f.readlines())

	settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])

	if not settings['source_model'] or settings['source_model'] == "auto":
		settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"

	if settings['half_p']:
		if not os.path.exists(get_halfp_model_path()):
			convert_to_halfp()

	messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")

	iterations_per_epoch = settings['iterations'] / settings['epochs']

	# save / validation rates arrive in epochs; the config wants iterations
	settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
	settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)

	iterations_per_epoch = int(iterations_per_epoch)

	if settings['save_rate'] < 1:
		settings['save_rate'] = 1

	settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
	if not os.path.exists(settings['validation_path']):
		settings['validation_enabled'] = False
		messages.append("Validation not found, disabling validation...")
	elif settings['validation_batch_size'] == 0:
		settings['validation_enabled'] = False
		messages.append("Validation batch size == 0, disabling validation...")
	else:
		with open(settings['validation_path'], 'r', encoding="utf-8") as f:
			validation_lines = len(f.readlines())

		# an empty validation set would clamp the batch size to 0, which is treated as invalid above
		if validation_lines == 0:
			settings['validation_enabled'] = False
			messages.append("Validation dataset is empty, disabling validation...")
		elif validation_lines < settings['validation_batch_size']:
			settings['validation_batch_size'] = validation_lines
			messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")

	settings['tokenizer_json'] = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0]

	if settings['gpus'] > get_device_count():
		settings['gpus'] = get_device_count()

	# what an utter mistake this was
	settings['optimizer'] = 'adamw' # if settings['gpus'] == 1 else 'adamw_zero'

	if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
		settings['learning_rate_scheme'] = "Multistep"

	settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[settings['learning_rate_scheme']]

	learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
	if settings['learning_rate_scheme'] == "MultiStepLR":
		if not settings['learning_rate_schedule']:
			settings['learning_rate_schedule'] = LEARNING_RATE_SCHEDULE
		elif isinstance(settings['learning_rate_schedule'],str):
			settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])

		settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] )

		learning_rate_schema.append(f"  gen_lr_steps: {settings['learning_rate_schedule']}")
		learning_rate_schema.append(f"  lr_gamma: 0.5")
	elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
		epochs = settings['epochs']
		# at least one restart: avoids dividing by zero below and accepts float input
		restarts = max(1, int(settings['learning_rate_restarts']))
		restart_period = int(epochs / restarts)

		if 'learning_rate_warmup' not in settings:
			settings['learning_rate_warmup'] = 0
		if 'learning_rate_min' not in settings:
			settings['learning_rate_min'] = 1e-08

		if 'learning_rate_period' not in settings:
			settings['learning_rate_period'] = [ iterations_per_epoch * restart_period for x in range(epochs) ]

		settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * restart_period for x in range(restarts) ] # [52, 104, 156, 208]

		if 'learning_rate_restart_weights' not in settings:
			weights = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
			# the last weight halves the previous one; a single restart has no previous weight
			if restarts > 1:
				weights[-1] = weights[-2] * 0.5
			settings['learning_rate_restart_weights'] = weights

		learning_rate_schema.append(f"  T_period: {settings['learning_rate_period']}")
		learning_rate_schema.append(f"  warmup: {settings['learning_rate_warmup']}")
		learning_rate_schema.append(f"  eta_min: !!float {settings['learning_rate_min']}")
		learning_rate_schema.append(f"  restarts: {settings['learning_rate_restarts']}")
		learning_rate_schema.append(f"  restart_weights: {settings['learning_rate_restart_weights']}")
	settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)

	# exactly one of pretrain_model_gpt / resume_state stays active; the other is commented out
	if settings['resume_state']:
		settings['source_model'] = f"# pretrain_model_gpt: '{settings['source_model']}'"
		settings['resume_state'] = f"resume_state: '{settings['resume_state']}'"
	else:
		settings['source_model'] = f"pretrain_model_gpt: '{settings['source_model']}'"
		settings['resume_state'] = f"# resume_state: '{settings['resume_state']}'"

	def use_template(template, out):
		"""Copy ``template`` to ``out``, replacing each ``${key}`` with the matching non-None setting."""
		with open(template, 'r', encoding="utf-8") as f:
			rendered = f.read()

		# plain text substitution avoids walking the YAML structure
		for k in settings:
			if settings[k] is None:
				continue
			rendered = rendered.replace(f"${{{k}}}", str(settings[k]))

		with open(out, 'w', encoding="utf-8") as f:
			f.write(rendered)

	if args.tts_backend == "tortoise":
		use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
	elif args.tts_backend == "vall-e":
		settings['model_name'] = "[ 'ar-quarter', 'nar-quarter' ]"
		use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')

	messages.append(f"Saved training output")
	return settings, messages

def import_voices(files, saveAs=None, progress=None):
	"""
	Imports voice samples (or embedded conditioning latents) into the voices directory.

	Each entry is read with `read_generate_settings`. If it carries embedded latents they are
	written to `<voice>/cond_latents.pth`; otherwise the file itself must be a WAV and is
	copied into the voice folder (optionally resampled to 44.1kHz and run through voicefixer).

	:param files: a single file or a list of files; each may be a path string or an upload
		object exposing `.name`
	:param saveAs: target voice name; when None, the `voice` field of the embedded settings
		is used
	:param progress: unused in this function
	:raises Exception: when no voice name can be determined, or a non-latent file is not a WAV
	"""
	global args

	if not isinstance(files, list):
		files = [files]

	for file in tqdm(files, desc="Importing voice files"):
		j, latents = read_generate_settings(file, read_latents=True)

		if j is not None and saveAs is None:
			saveAs = j['voice']
		if saveAs is None or saveAs == "":
			raise Exception("Specify a voice name")

		outdir = f'{get_voice_dir()}/{saveAs}/'
		os.makedirs(outdir, exist_ok=True)

		if latents:
			latents_path = f'{outdir}/cond_latents.pth'
			# `latents` holds the decoded bytes; log the destination path, not the blob
			print(f"Importing latents to {latents_path}")
			with open(latents_path, 'wb') as f:
				f.write(latents)
			print(f"Imported latents to {latents_path}")
			continue

		# upload objects expose their temp path via `.name`; plain strings already are paths
		filename = getattr(file, 'name', file)
		if filename[-4:] != ".wav":
			raise Exception("Please convert to a WAV first")

		path = f"{outdir}/{os.path.basename(filename)}"
		print(f"Importing voice to {path}")

		waveform, sample_rate = torchaudio.load(filename)

		if args.voice_fixer:
			if not voicefixer:
				load_voicefixer()

			# voicefixer is fed a 44.1kHz copy of the sample
			waveform, sample_rate = resample(waveform, sample_rate, 44100)
			torchaudio.save(path, waveform, sample_rate)

			# load_voicefixer() leaves `voicefixer` as None when it cannot be initialized
			if voicefixer:
				print(f"Running 'voicefixer' on voice sample: {path}")
				voicefixer.restore(
					input = path,
					output = path,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
				)
			else:
				print(f"voicefixer unavailable, keeping resampled voice sample: {path}")
		else:
			torchaudio.save(path, waveform, sample_rate)

		print(f"Imported voice to {path}")

def relative_paths( dirs ):
	"""
	Converts each path in `dirs` into a './'-prefixed path relative to the current working
	directory, always using forward slashes as separators.
	"""
	converted = []
	for entry in dirs:
		rel = os.path.relpath( entry )
		converted.append( './' + rel.replace("\\", "/") )
	return converted

def get_voice( name, dir=None, load_latents=True, extensions=("wav", "mp3", "flac") ):
	"""
	Lists the sample files that make up a voice.

	:param name: voice folder name (may include a sub-folder, e.g. "speaker/take1")
	:param dir: voices root; defaults to `get_voice_dir()`, resolved at call time
	:param load_latents: also include `.pth` files
	:param extensions: accepted file extensions without the leading dot; never modified
	:returns: sorted list of matching file paths, or None if the voice folder does not exist
	"""
	if dir is None:
		dir = get_voice_dir()

	subj = f'{dir}/{name}/'
	if not os.path.isdir(subj):
		return

	# build a private set so neither the caller's list nor the default grows on every call
	allowed = set(extensions)
	if load_latents:
		allowed.add("pth")

	voice = []
	for file in os.listdir(subj):
		ext = os.path.splitext(file)[-1][1:]
		if ext not in allowed:
			continue
		voice.append(f'{subj}/{file}')

	return sorted( voice )

def get_voice_list(dir=None, append_defaults=False, extensions=("wav", "mp3", "flac", "pth")):
	"""
	Lists every voice (and one level of voice sub-folders) under the voices directory that
	contains samples.

	:param dir: voices root; defaults to `get_voice_dir()`, resolved at call time; created if missing
	:param append_defaults: append the pseudo-voices "random" and "microphone" after the sorted list
	:param extensions: sample extensions that make a folder count as a voice
	:returns: sorted voice names; nested voices are reported as "parent/child"
	"""
	if dir is None:
		dir = get_voice_dir()

	defaults = [ "random", "microphone" ]
	os.makedirs(dir, exist_ok=True)

	res = []
	for name in os.listdir(dir):
		if name in defaults:
			continue
		voice_dir = f'{dir}/{name}'
		if not os.path.isdir(voice_dir) or not os.listdir(voice_dir):
			continue

		# pass a fresh list each time: get_voice may append to the list it is given
		if get_voice( name, dir=dir, extensions=list(extensions) ):
			res.append(name)
			continue

		# no samples at the top level: look one level deeper for grouped voices
		for subdir in os.listdir(voice_dir):
			if not os.path.isdir(f'{voice_dir}/{subdir}'):
				continue
			if get_voice( f'{name}/{subdir}', dir=dir, extensions=list(extensions) ):
				res.append(f'{name}/{subdir}')

	res = sorted(res)

	if append_defaults:
		res = res + defaults

	return res

def get_valle_models(dir="./training/"):
	"""Returns the `config.yaml` path of every entry under `dir` that has one."""
	found = []
	for entry in os.listdir(dir):
		candidate = f'{dir}/{entry}/config.yaml'
		if os.path.exists(candidate):
			found.append(candidate)
	return found

def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
	"""
	Lists selectable autoregressive model checkpoints.

	Order: the stock `autoregressive.pth` (plus the half-precision copy if it exists), the
	`.pth` files in `dir` (sorted), then every
	`./training/<name>/finetune/models/<step>_gpt.pth` sorted by step.

	:param dir: folder of user-supplied finetunes (created if missing)
	:param prefixed: prefix each entry with "[<first 8 chars of hash_file(path)>] "
	:param auto: prepend the "auto" pseudo-entry
	:returns: list of './'-relative paths, optionally hash-prefixed
	"""
	os.makedirs(dir, exist_ok=True)
	base = [get_model_path('autoregressive.pth')]
	halfp = get_halfp_model_path()
	if os.path.exists(halfp):
		base.append(halfp)

	additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])

	found = []
	training_root = './training/'
	trainings = os.listdir(training_root) if os.path.isdir(training_root) else []
	for training in trainings:
		models_dir = f'./training/{training}/finetune/models/'
		if not os.path.isdir(models_dir):
			continue
		# checkpoints are "<step>_gpt.pth"; skip files whose step is not an integer
		steps = sorted([ int(d[:-8]) for d in os.listdir(models_dir) if d[-8:] == "_gpt.pth" and d[:-8].isdigit() ])
		found = found + [ f'{models_dir}{step}_gpt.pth' for step in steps ]

	res = base + additionals + found
	paths = relative_paths(res)

	if prefixed:
		# relativize first, then prefix: prefixing first produced "./[hash] ./path", which the
		# "^\[hash\] (.+)$" parsing in update_autoregressive_model cannot strip
		paths = [ f'[{hash_file(path)[:8]}] {rel}' for path, rel in zip(res, paths) ]

	if auto:
		paths = ["auto"] + paths

	return paths

def get_diffusion_models(dir="./models/finetunes/", prefixed=False):
	"""
	Lists selectable diffusion model checkpoints.

	Only the stock `diffusion_decoder.pth` is returned (as a './'-relative path); `dir` and
	`prefixed` are accepted but currently unused.
	"""
	stock = get_model_path('diffusion_decoder.pth')
	return relative_paths([ stock ])

def get_tokenizer_jsons( dir="./models/tokenizers/" ):
	"""
	Lists tokenizer vocab files: the bundled TorToiSe tokenizer first, followed by any
	`.json` files in `dir` (sorted), all as './'-relative paths.
	"""
	bundled = "./modules/tortoise-tts/tortoise/data/tokenizer.json"
	extras = []
	if os.path.isdir(dir):
		extras = sorted( f'{dir}/{name}' for name in os.listdir(dir) if name[-5:] == ".json" )
	return relative_paths([ bundled, *extras ])

def tokenize_text( text, config=None, stringed=True, skip_specials=False ):
	"""
	Tokenizes `text` and returns the token ids together with their decoded pieces.

	The loaded TTS model's tokenizer is used when a model is loaded. Otherwise a
	`VoiceBpeTokenizer` is built from `config`, which falls back to `args.tokenizer_json` and
	then to the first entry of `get_tokenizer_jsons()`.

	:param skip_specials: drop special tokens when decoding
	:returns: when `stringed`, a two-line string "<ids>\\n<pieces>"; otherwise the decoded
		text split on single spaces
	"""
	from tortoise.utils.tokenizer import VoiceBpeTokenizer

	if not config:
		config = args.tokenizer_json or get_tokenizer_jsons()[0]

	tokenizer = tts.tokenizer if tts else VoiceBpeTokenizer(config)

	token_ids = tokenizer.encode(text)
	pieces = tokenizer.tokenizer.decode(token_ids, skip_special_tokens=skip_specials).split(" ")

	if not stringed:
		return pieces
	return "\n".join([ str(token_ids), str(pieces) ])

def get_dataset_list(dir="./training/"):
	"""Returns the sorted names of sub-folders of `dir` that contain a `train.txt` file."""
	names = []
	for entry in os.listdir(dir):
		folder = os.path.join(dir, entry)
		if os.path.isdir(folder) and "train.txt" in os.listdir(folder):
			names.append(entry)
	names.sort()
	return names

def get_training_list(dir="./training/"):
	"""
	Lists training configuration files under `dir`.

	The TorToiSe backend looks for `<dir>/<name>/train.yaml`; any other backend looks for
	`<dir>/<name>/config.yaml`.

	:returns: sorted list of config paths, built from `dir` (the default yields
		"./training/<name>/<file>")
	"""
	filename = "train.yaml" if args.tts_backend == "tortoise" else "config.yaml"
	return sorted([
		os.path.join(dir, d, filename)
		for d in os.listdir(dir)
		if os.path.isdir(os.path.join(dir, d)) and filename in os.listdir(os.path.join(dir, d))
	])

def pad(num, zeroes):
	"""Zero-fills `num` to a width of `zeroes + 1` characters, e.g. pad(5, 2) -> "005"."""
	width = zeroes + 1
	text = str(num)
	return text.zfill(width)

def curl(url):
	"""
	Fetches `url` and decodes its body as JSON.

	:returns: the parsed JSON value, or None if the request, decoding, or parsing fails
		(the error is printed)
	"""
	try:
		req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
		# `with` closes the connection even when reading fails
		with urllib.request.urlopen(req) as conn:
			data = conn.read()
		return json.loads(data.decode())
	except Exception as e:
		print(e)
		return None

def check_for_updates( dir = None ):
	"""
	Compares a checkout's last-fetched commit with its remote's first listed branch head.

	With no `dir`, checks the main repo and the `dlas` and `tortoise-tts` submodules in turn.
	The remote is queried through `/api/v1/repos/<owner>/<repo>/branches/` (Gitea API), so
	only Gitea-hosted remotes are supported.

	:param dir: a `.git` directory containing `FETCH_HEAD`
	:returns: True if the remote commit differs from `FETCH_HEAD`; False otherwise, including
		every failure case; None when `dir` is None
	"""
	if dir is None:
		check_for_updates("./.git/")
		check_for_updates("./.git/modules/dlas/")
		check_for_updates("./.git/modules/tortoise-tts/")
		return

	fetch_head = f'{dir}/FETCH_HEAD'
	if not os.path.isfile(fetch_head):
		print(f"Cannot check for updates for {dir}: not from a git repo")
		return False

	with open(fetch_head, 'r', encoding="utf-8") as f:
		head = f.read()

	match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
	if not match:
		print(f"Cannot check for updates for {dir}: cannot parse FETCH_HEAD")
		return False

	local, host, owner, repo = match[0]

	res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances

	# an error reply (None, an empty list, or a JSON object instead of a branch list)
	# must report a failure rather than crash
	try:
		remote = res[0]["commit"]["id"]
	except (TypeError, KeyError, IndexError):
		remote = None

	if remote is None:
		print(f"Cannot check for updates for {dir}: cannot fetch from remote")
		return False

	if remote != local:
		print(f"New version found for {dir}: {local[:8]} => {remote[:8]}")
		return True

	return False

def notify_progress(message, progress=None, verbose=True):
	"""
	Reports a status message to the console and/or a progress tracker.

	:param message: text to report
	:param progress: optional progress callable, invoked as `progress(0, desc=message)`;
		when None the message is written with `tqdm.write`
	:param verbose: when a `progress` callable is given, also print the message
	"""
	if progress is None:
		# tqdm.write already is the console output here; printing as well duplicated every message
		tqdm.write(message)
		return

	if verbose:
		print(message)
	progress(0, desc=message)

def get_args():
	"""Returns the module-wide argument namespace (populated by `setup_args`)."""
	return args

def setup_args(cli=False):
	"""
	Parses command-line arguments into the module-wide `args` namespace and returns it.

	Defaults come from `_load_default_arguments()` (built-ins overlaid with
	`./config/exec.json`). Afterwards this applies `--device-override`, warns when an
	automatically deduced sample batch size is 1, and derives `listen_host`, `listen_port`
	and `listen_path` from `--listen`.

	:param cli: when True, unknown arguments are ignored and abbreviated flags are disallowed
	:returns: the parsed namespace (also stored in the global `args`)
	"""
	global args

	default_arguments = _load_default_arguments()

	parser = argparse.ArgumentParser(allow_abbrev=not cli)
	parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
	parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
	parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
	parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
	parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
	# store_true: passing the flag must turn embedding OFF (store_false could never flip it)
	parser.add_argument("--no-embed-output-metadata", action='store_true', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
	parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
	parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
	parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
	parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
	parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
	parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
	parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
	parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
	parser.add_argument("--unsqueeze-sample-batches", default=default_arguments['unsqueeze-sample-batches'], action='store_true', help="Unsqueezes sample batches to process one by one after sampling")
	parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
	parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
	parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
	parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
	parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory")

	parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token")
	parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.")

	parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
	parser.add_argument("--diffusion-model", default=default_arguments['diffusion-model'], help="Specifies which diffusion model to use for sampling.")
	# takes a value (a vocoder name), so it must not be a store_true flag
	parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], help="Specifies with vocoder to use")
	parser.add_argument("--tokenizer-json", default=default_arguments['tokenizer-json'], help="Specifies which tokenizer json to use for tokenizing.")

	parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")

	parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")

	# takes a value (a backend name), so it must not be a store_true flag
	parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
	parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
	parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")

	parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
	parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")

	parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
	if cli:
		args, unknown = parser.parse_known_args()
	else:
		args = parser.parse_args()

	args.embed_output_metadata = not args.no_embed_output_metadata

	# only forward an explicit override to the device layer
	if args.device_override:
		set_device_name(args.device_override)

	if args.sample_batch_size == 0 and get_device_batch_size() == 1:
		print("!WARNING! Automatically deduced sample batch size returned 1.")

	args.listen_host, args.listen_port, args.listen_path = _parse_listen(args.listen)

	return args

def _load_default_arguments():
	"""
	Returns the built-in argument defaults (hyphenated keys), overlaid with any overrides
	stored in `./config/exec.json`. A malformed file is reported and otherwise ignored.
	"""
	default_arguments = {
		'share': False,
		'listen': None,
		'check-for-updates': False,
		'models-from-local-only': False,
		'low-vram': False,
		'sample-batch-size': None,
		'unsqueeze-sample-batches': False,
		'embed-output-metadata': True,
		'latents-lean-and-mean': True,
		'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
		'voice-fixer-use-cuda': True,

		'force-cpu-for-conditioning-latents': False,
		'defer-tts-load': False,
		'device-override': None,
		'prune-nonfinal-outputs': True,
		'concurrency-count': 2,
		'autocalculate-voice-chunk-duration-size': 10,

		'output-sample-rate': 44100,
		'output-volume': 1,
		'results-folder': "./results/",

		'hf-token': None,
		'tts-backend': TTSES[0],

		'autoregressive-model': None,
		'diffusion-model': None,
		'vocoder-model': VOCODERS[-1],
		'tokenizer-json': None,

		'phonemizer-backend': 'espeak',

		'valle-model': None,

		'whisper-backend': 'openai/whisper',
		'whisper-model': "base",
		'whisper-batchsize': 1,

		'training-default-halfp': False,
		'training-default-bnb': True,
	}

	config_path = './config/exec.json'
	if os.path.isfile(config_path):
		with open(config_path, 'r', encoding="utf-8") as f:
			try:
				overrides = json.load(f)
				for k in overrides:
					default_arguments[k] = overrides[k]
			except Exception as e:
				# a corrupt settings file must not block startup; fall back to the built-ins
				print(e)

	return default_arguments

def _parse_listen(listen):
	"""
	Splits a `--listen` value of the form "[host:port][/path]" into `(host, port, path)`.

	A missing host defaults to "127.0.0.1" and a missing path to "/". The port is an int, or
	None when absent or 0. Returns `(None, None, None)` when `listen` is empty or unparsable.
	"""
	if not listen:
		return None, None, None

	try:
		match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", listen)[0]
	except (IndexError, TypeError):
		return None, None, None

	host = match[0] if match[0] != "" else "127.0.0.1"
	port = int(match[1]) if match[1] != "" else None
	path = match[2] if match[2] != "" else "/"

	if port == 0:
		port = None

	return host, port, path

def get_default_settings( hypenated=True ):
	"""
	Snapshots the current `args` into a settings dict (the format of `./config/exec.json`).

	:param hypenated: when True keys are CLI-style ("low-vram"); when False they are
		attribute-style ("low_vram")
	:returns: ordered dict of setting name -> current value; a falsy `listen` becomes None
	"""
	keys = [
		'listen', 'share', 'low-vram', 'check-for-updates', 'models-from-local-only',
		'force-cpu-for-conditioning-latents', 'defer-tts-load', 'prune-nonfinal-outputs',
		'device-override', 'sample-batch-size', 'unsqueeze-sample-batches',
		'embed-output-metadata', 'latents-lean-and-mean', 'voice-fixer', 'voice-fixer-use-cuda',
		'concurrency-count', 'output-sample-rate', 'autocalculate-voice-chunk-duration-size',
		'output-volume', 'results-folder',

		'hf-token', 'tts-backend',

		'autoregressive-model', 'diffusion-model', 'vocoder-model', 'tokenizer-json',

		'phonemizer-backend',

		'valle-model',

		'whisper-backend', 'whisper-model', 'whisper-batchsize',

		'training-default-halfp', 'training-default-bnb',
	]

	snapshot = {}
	for key in keys:
		attr = key.replace("-", "_")
		value = getattr(args, attr)
		if key == 'listen' and not value:
			value = None
		snapshot[key if hypenated else attr] = value
	return snapshot

def update_args( **kwargs ):
	"""
	Applies keyword overrides to the module-wide `args` and persists them to `./config/exec.json`.

	Only settings known to `get_default_settings` (underscore form, e.g. `low_vram`) are
	applied; other keywords are ignored. Settings that are not passed keep their current value.
	Every known setting is copied through as given, including `output_sample_rate`.
	"""
	global args

	current = get_default_settings(hypenated=False)
	for key, value in current.items():
		setattr(args, key, kwargs.get(key, value))

	save_args_settings()

def save_args_settings():
	"""Writes the current `args` (hyphenated keys, tab-indented JSON) to `./config/exec.json`."""
	settings = get_default_settings()
	# serialize before opening the file: opening with 'w' truncates it, so a serialization
	# error afterwards would otherwise wipe every saved setting
	payload = json.dumps(settings, indent='\t')

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
		f.write(payload)

# super kludgy )`;
def import_generate_settings(file = None):
	"""
	Loads generation settings from `file` (default `./config/generate.json`) on top of
	built-in defaults. Keys absent from the file keep their default; latents are not read.
	"""
	defaults = dict(
		text=None,
		delimiter=None,
		emotion=None,
		prompt=None,
		voice="random",
		mic_audio=None,
		voice_latents_chunks=None,
		candidates=None,
		seed=None,
		num_autoregressive_samples=16,
		diffusion_iterations=30,
		temperature=0.8,
		diffusion_sampler="DDIM",
		breathing_room=8,
		cvvp_weight=0.0,
		top_p=0.8,
		diffusion_temperature=1.0,
		length_penalty=1.0,
		repetition_penalty=2.0,
		cond_free_k=2.0,
		experimentals=None,
	)

	source = file if file else "./config/generate.json"
	loaded, _ = read_generate_settings(source, read_latents=False)
	if loaded is not None:
		defaults.update(loaded)

	return defaults

def reset_generate_settings():
	"""
	Clears the saved generation settings by writing an empty JSON object to
	`./config/generate.json`, then returns the resulting defaults.
	"""
	# like save_args_settings, create the config folder in case it does not exist yet
	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps({}, indent='\t') )
	return import_generate_settings()

def read_generate_settings(file, read_latents=True):
	"""
	Reads generation settings embedded in a WAV's lyrics tag or stored in a JSON file.

	:param file: a path, an upload object exposing `.name`, or a one-element list of either
	:param read_latents: base64-decode the `latents` field into bytes; the field is always
		removed from the returned settings
	:returns: `(settings, latents)`. `settings` is None when nothing usable could be read,
		and `latents` is None unless requested and present. A numeric `time` field is
		formatted to three decimals.
	"""
	j = None
	latents = None

	if isinstance(file, list) and len(file) == 1:
		file = file[0]

	# best-effort: missing, unreadable, or malformed files simply yield no settings
	try:
		if file is not None:
			if hasattr(file, 'name'):
				file = file.name

			if file[-4:] == ".wav":
				metadata = music_tag.load_file(file)
				if 'lyrics' in metadata:
					j = json.loads(str(metadata['lyrics']))
			elif file[-5:] == ".json":
				with open(file, 'r', encoding="utf-8") as f:
					j = json.load(f)
	except Exception:
		pass

	# callers index into the result as a mapping; anything else is unusable
	if not isinstance(j, dict):
		j = None

	if j is not None:
		if 'latents' in j:
			if read_latents:
				latents = base64.b64decode(j['latents'])
			del j['latents']

		if "time" in j:
			# the value may already be a string (this function itself stores it as one)
			try:
				j["time"] = "{:.3f}".format(float(j["time"]))
			except (TypeError, ValueError):
				pass

	return (
		j,
		latents,
	)

def version_check_tts( min_version ):
	"""
	Compares `min_version` against the loaded TTS model's `version` tuple.

	:raises Exception: if no TTS model is loaded
	:returns: False when the model has no `version` attribute; otherwise True when
		min[0] > tts[0], or min[1] > tts[1], or min[2] >= tts[2]

	NOTE(review): the components are compared independently, not lexicographically, and the
	test succeeds when `min_version` is the larger one (e.g. min (2,4,5) vs tts (2,4,6) gives
	False). Confirm the intended direction against callers before relying on it.
	"""
	global tts
	if not tts:
		raise Exception("TTS is not initialized")

	if not hasattr(tts, 'version'):
		return False

	current = tts.version
	return bool(
		min_version[0] > current[0]
		or min_version[1] > current[1]
		or min_version[2] >= current[2]
	)

def load_tts( restart=False, 
	# TorToiSe configs
	autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
	# VALL-E configs
	valle_model=None,
):
	"""
	Loads the TTS model for `args.tts_backend` into the module-wide `tts`.

	Model arguments that are passed are stored back into `args`; omitted ones fall back to
	the values already in `args`. For TorToiSe, an autoregressive model of "auto" is resolved
	with `deduce_autoregressive_model()`.

	:param restart: unload any currently loaded model first
	:returns: the TTS object (unchanged from before if the backend is not recognised)
	"""
	global args
	global tts
	global tts_loading

	if restart:
		unload_tts()

	# `tts_loading` must be the module global: the update_* helpers read it to tell
	# "still loading" from "not loaded". Reset it even if model construction raises.
	tts_loading = True
	try:
		if args.tts_backend == "tortoise":
			if autoregressive_model:
				args.autoregressive_model = autoregressive_model
			else:
				autoregressive_model = args.autoregressive_model

			if autoregressive_model == "auto":
				autoregressive_model = deduce_autoregressive_model()

			if diffusion_model:
				args.diffusion_model = diffusion_model
			else:
				diffusion_model = args.diffusion_model

			if vocoder_model:
				args.vocoder_model = vocoder_model
			else:
				vocoder_model = args.vocoder_model

			if tokenizer_json:
				args.tokenizer_json = tokenizer_json
			else:
				tokenizer_json = args.tokenizer_json

			if get_device_name() == "cpu":
				print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")

			print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
			tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
		elif args.tts_backend == "vall-e":
			if valle_model:
				args.valle_model = valle_model
			else:
				valle_model = args.valle_model

			print(f"Loading VALL-E... (Config: {valle_model})")
			tts = VALLE_TTS(config=args.valle_model)
		elif args.tts_backend == "bark":
			print(f"Loading Bark...")
			tts = Bark_TTS(small=args.low_vram)
	finally:
		tts_loading = False

	print("Loaded TTS, ready for generation.")
	return tts

def unload_tts():
	"""Releases the loaded TTS model (if any), then runs garbage collection."""
	global tts

	if not tts:
		do_gc()
		return

	del tts
	tts = None
	print("Unloaded TTS")
	do_gc()

def reload_tts():
	"""Unloads the current TTS model and loads it again from the settings in `args`."""
	# restart=True makes load_tts call unload_tts() before loading
	load_tts(restart=True)

def get_current_voice():
	"""
	Returns the voice currently in use.

	Prefers the module-wide `current_voice`; otherwise falls back to the `voice` saved in
	`./config/generate.json`. Returns None when neither is available.
	"""
	global current_voice
	if current_voice:
		return current_voice

	settings, _ = read_generate_settings("./config/generate.json", read_latents=False)

	# test the key in the settings dict, not a substring of the voice name
	# (which also raised KeyError when the key was missing)
	if settings and settings.get("voice"):
		return settings["voice"]

	return None

def deduce_autoregressive_model(voice=None):
	"""
	Picks the autoregressive model for `voice` (default: `get_current_voice()`).

	Resolution order:
	  1. `./models/finetunes/<voice>.pth`
	  2. the highest-step `./training/<voice>/finetune/models/<step>_gpt.pth`
	  3. `args.autoregressive_model`, if it is set and not "auto"
	  4. the stock `autoregressive.pth`
	"""
	if not voice:
		voice = get_current_voice()

	if voice:
		if os.path.exists(f'./models/finetunes/{voice}.pth'):
			return f'./models/finetunes/{voice}.pth'

		dir = f'./training/{voice}/finetune/models/'
		if os.path.isdir(dir):
			# only "<integer step>_gpt.pth" counts; other names would break int()
			counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" and d[:-8].isdigit() ])
			if counts:
				return f'{dir}/{counts[-1]}_gpt.pth'

	# an unset (None/empty) setting is not a model path
	if args.autoregressive_model and args.autoregressive_model != "auto":
		return args.autoregressive_model

	return get_model_path('autoregressive.pth')

def update_autoregressive_model(autoregressive_model_path):
	"""
	Switches the autoregressive model and persists the choice to the saved settings.

	Accepts "auto" (resolved with `deduce_autoregressive_model()`), a plain path, or a
	"[<8 hex chars>] <path>" entry as produced by `get_autoregressive_models(prefixed=True)`.

	:returns: the loaded path, or None if the path was invalid, no TTS is loaded, or the
		model was already active
	:raises Exception: for a non-TorToiSe backend, or while the TTS is still initializing
	"""
	global tts

	if args.tts_backend != "tortoise":
		# raising a plain string is itself a TypeError; raise a real exception
		raise Exception(f"Unsupported backend: {args.tts_backend}")

	if autoregressive_model_path == "auto":
		autoregressive_model_path = deduce_autoregressive_model()
	elif autoregressive_model_path:
		match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
		if match:
			autoregressive_model_path = match[0]

	if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
		print(f"Invalid model: {autoregressive_model_path}")
		return

	args.autoregressive_model = autoregressive_model_path
	save_args_settings()
	print(f'Stored autoregressive model to settings: {autoregressive_model_path}')

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		return

	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	if autoregressive_model_path == tts.autoregressive_model_path:
		return

	tts.load_autoregressive_model(autoregressive_model_path)

	do_gc()

	return autoregressive_model_path

def update_diffusion_model(diffusion_model_path):
	"""
	Switches the diffusion model and persists the choice to the saved settings.

	Accepts a plain path or a "[<8 hex chars>] <path>" entry.

	:returns: the loaded path, or None if the path was invalid, no TTS is loaded, or the
		model was already active
	:raises Exception: for a non-TorToiSe backend, or while the TTS is still initializing
	"""
	global tts

	if args.tts_backend != "tortoise":
		# raising a plain string is itself a TypeError; raise a real exception
		raise Exception(f"Unsupported backend: {args.tts_backend}")

	if diffusion_model_path:
		match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
		if match:
			diffusion_model_path = match[0]

	if not diffusion_model_path or not os.path.exists(diffusion_model_path):
		print(f"Invalid model: {diffusion_model_path}")
		return

	args.diffusion_model = diffusion_model_path
	save_args_settings()
	print(f'Stored diffusion model to settings: {diffusion_model_path}')

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		return

	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	# NOTE(review): only reachable if a file literally named "auto" exists, because the
	# existence check above returns first; kept pending a decision on supporting "auto"
	if diffusion_model_path == "auto":
		diffusion_model_path = deduce_diffusion_model()

	if diffusion_model_path == tts.diffusion_model_path:
		return

	tts.load_diffusion_model(diffusion_model_path)

	do_gc()

	return diffusion_model_path

def update_vocoder_model(vocoder_model):
	"""
	Switches the vocoder, persisting the choice to the saved settings.

	:returns: `vocoder_model`, or None when no TTS model is loaded
	:raises Exception: for a non-TorToiSe backend, or while the TTS is still initializing
	"""
	global tts

	if args.tts_backend != "tortoise":
		# raising a plain string is itself a TypeError; raise a real exception
		raise Exception(f"Unsupported backend: {args.tts_backend}")

	args.vocoder_model = vocoder_model
	save_args_settings()
	print(f'Stored vocoder model to settings: {vocoder_model}')

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		return

	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	print(f"Loading model: {vocoder_model}")
	tts.load_vocoder_model(vocoder_model)
	print(f"Loaded model: {tts.vocoder_model}")

	do_gc()

	return vocoder_model

def update_tokenizer(tokenizer_json):
	"""
	Switches the tokenizer vocab, persisting the choice to the saved settings.

	:returns: `tokenizer_json`, or None when no TTS model is loaded
	:raises Exception: for a non-TorToiSe backend, or while the TTS is still initializing
	"""
	global tts

	if args.tts_backend != "tortoise":
		# raising a plain string is itself a TypeError; raise a real exception
		raise Exception(f"Unsupported backend: {args.tts_backend}")

	args.tokenizer_json = tokenizer_json
	save_args_settings()
	print(f'Stored tokenizer to settings: {tokenizer_json}')

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		return

	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	print(f"Loading tokenizer vocab: {tokenizer_json}")
	tts.load_tokenizer_json(tokenizer_json)
	print(f"Loaded tokenizer vocab: {tts.tokenizer_json}")

	do_gc()

	# previously returned the undefined name `vocoder_model` (NameError)
	return tokenizer_json

def load_voicefixer(restart=False):
	"""
	Instantiates `voicefixer.VoiceFixer` into the module-wide `voicefixer`.

	:param restart: unload any existing instance first
	On any failure (for example the package is not installed) the error is printed and
	`voicefixer` is set to None.
	"""
	global voicefixer

	if restart:
		unload_voicefixer()

	try:
		print("Loading Voicefixer")
		from voicefixer import VoiceFixer
		instance = VoiceFixer()
	except Exception as e:
		print(f"Error occurred while tring to initialize voicefixer: {e}")
		if voicefixer:
			del voicefixer
		voicefixer = None
		return

	voicefixer = instance
	print("Loaded Voicefixer")

def unload_voicefixer():
	"""Drops the loaded voicefixer instance (if any), then runs garbage collection."""
	global voicefixer

	is_loaded = bool(voicefixer)
	if is_loaded:
		del voicefixer
		voicefixer = None
		print("Unloaded Voicefixer")

	do_gc()

def load_whisper_model(language=None, model_name=None, progress=None):
	"""
	Loads a Whisper transcription model for `args.whisper_backend` into the module-wide
	`whisper_model` (and, for m-bain/whisperx, an alignment model into `whisper_align_model`).

	:param language: language code; picks the `<model>.<language>` variant when it is listed
		in WHISPER_SPECIALIZED_MODELS, and is passed to the whispercpp / whisperx backends
	:param model_name: Whisper model name; when given it is also saved as the new default,
		otherwise `args.whisper_model` is used
	:param progress: optional progress tracker forwarded to `notify_progress`
	:raises Exception: if `args.whisper_backend` is not one of WHISPER_BACKENDS
	"""
	global whisper_model
	global whisper_align_model

	if args.whisper_backend not in WHISPER_BACKENDS:
		raise Exception(f"unavailable backend: {args.whisper_backend}")

	if not model_name:
		model_name = args.whisper_model
	else:
		args.whisper_model = model_name
		save_args_settings()

	if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
		model_name = f'{model_name}.{language}'
		print(f"Loading specialized model for language: {language}")

	notify_progress(f"Loading Whisper model: {model_name}", progress=progress)

	if args.whisper_backend == "openai/whisper":
		import whisper
		try:
			#is it possible for model to fit on vram but go oom later on while executing on data?
			whisper_model = whisper.load_model(model_name)
		except Exception as e:
			# a bare `except:` also swallowed KeyboardInterrupt/SystemExit; report the real cause
			print(f"Failed to load Whisper on the default device: {e}")
			print("Out of VRAM memory. falling back to loading Whisper on CPU.")
			whisper_model = whisper.load_model(model_name, device="cpu")
	elif args.whisper_backend == "lightmare/whispercpp":
		from whispercpp import Whisper
		if not language:
			language = 'auto'

		b_lang = language.encode('ascii')
		whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
	elif args.whisper_backend == "m-bain/whisperx":
		import whisper, whisperx
		device = "cuda" if get_device_name() == "cuda" else "cpu"
		whisper_model = whisperx.load_model(model_name, device)
		whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)

	print("Loaded Whisper model")

def unload_whisper():
	"""Releases the Whisper model and any alignment model set by `load_whisper_model`, then runs GC."""
	global whisper_model
	global whisper_align_model

	if whisper_align_model:
		del whisper_align_model
		whisper_align_model = None

	if not whisper_model:
		do_gc()
		return

	del whisper_model
	whisper_model = None
	print("Unloaded Whisper")
	do_gc()

# shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74
def merge_models( primary_model_name, secondary_model_name, alpha, progress=gr.Progress() ):
	"""
	Linearly interpolates two checkpoints and saves the result under ./models/finetunes/.

	For each key, a float16/float32 tensor becomes `(1 - alpha) * primary + alpha * secondary`.
	The primary's value is kept unchanged for blacklisted keys, keys missing from the
	secondary, non-tensor entries, other dtypes, and tensors whose shapes differ.

	:param alpha: weight of the secondary model (0 keeps the primary, 1 takes the secondary)
	:param progress: unused in this function
	:returns: a status message naming the saved file

	NOTE(review): `torch.load` unpickles the files, which can execute arbitrary code;
	only merge checkpoints from trusted sources.
	"""
	key_blacklist = []
	float_types = (torch.float32, torch.float16)

	def weighted_sum(theta0, theta1, alpha):
		return ((1 - alpha) * theta0) + (alpha * theta1)

	def read_model( filename ):
		print(f"Loading {filename}")
		return torch.load(filename)

	theta_0 = read_model(primary_model_name)
	theta_1 = read_model(secondary_model_name)

	# assigning to existing keys during iteration is safe: the dict's size never changes
	for key in tqdm(theta_0.keys(), desc="Merging..."):
		if key in key_blacklist:
			print("Skipping ignored key:", key)
			continue

		if key not in theta_1:
			print("Skipping key missing from secondary model:", key)
			continue

		a = theta_0[key]
		b = theta_1[key]

		if not torch.is_tensor(a) or not torch.is_tensor(b):
			print("Skipping non-tensor key:", key)
			continue

		if a.dtype not in float_types:
			print("Skipping key:", key, a.dtype)
			continue

		if b.dtype not in float_types:
			print("Skipping key:", key, b.dtype)
			continue

		if a.shape != b.shape:
			print("Skipping key with mismatched shapes:", key, tuple(a.shape), tuple(b.shape))
			continue

		theta_0[key] = weighted_sum(a, b, alpha)

	del theta_1

	primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0]
	secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0]
	suffix = "{:.3f}".format(alpha)
	output_dir = './models/finetunes/'
	os.makedirs(output_dir, exist_ok=True)
	output_path = f'{output_dir}{primary_basename}_{secondary_basename}_{suffix}_merge.pth'

	torch.save(theta_0, output_path)
	message = f"Saved to {output_path}"
	print(message)
	return message