From b6440091fbfbf44deffafa82ec284dad274a8e9e Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Wed, 26 Apr 2023 04:48:09 +0000
Subject: [PATCH] Very, very, VERY barebones integration with Bark
 (documentation soon)

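Adds "bark" as a third TTS backend alongside tortoise and vall-e:

- a Bark_TTS wrapper that preloads Bark's text/coarse/fine/codec models
  (small variants when args.low_vram is set)
- .npz speaker prompts (semantic/coarse/fine tokens) generated from a
  random whisper-transcribed segment, in lieu of conditioning latents
- a generate_bark() path mirroring generate_valle()
- web UI: backend-dependent visibility for generation settings and the
  Training tab, training graphs plotted against iterations instead of
  epochs, and graph limits split into X/Y min/max
- misc: re-enables whisperx batched VAD transcription and the vall-e
  *.loss.nll metric keys
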
---
 src/utils.py | 456 ++++++++++++++++++++++++++++++++++++++++++++++-----
 src/webui.py |  54 +++---
 2 files changed, 450 insertions(+), 60 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 9b07d93..3cf50fe 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -31,6 +31,7 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
+import numpy as np
 
 from glob import glob
 from datetime import datetime
@@ -65,6 +66,7 @@ MIN_TRAINING_DURATION = 0.6
 MAX_TRAINING_DURATION = 11.6097505669
 
 VALLE_ENABLED = False
+BARK_ENABLED = False
 
 try:
 	from vall_e.emb.qnt import encode as valle_quantize
@@ -76,11 +78,98 @@ try:
 
 	VALLE_ENABLED = True
 except Exception as e:
+	if False: # should be: args.tts_backend == "vall-e" (args is not parsed yet at import time)
+		raise e
 	pass
 
 if VALLE_ENABLED:
 	TTSES.append('vall-e')
 
+try:
+	from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
+	from bark.api import generate_audio as bark_generate_audio
+	from encodec.utils import convert_audio
+
+	from scipy.io.wavfile import write as write_wav
+
+	BARK_ENABLED = True
+except Exception as e:
+	if False: # should be: args.tts_backend == "bark" (args is not parsed yet at import time)
+		raise e
+	pass
+
+if BARK_ENABLED:
+	TTSES.append('bark')
+	class Bark_TTS():
+		def __init__(self, small=False):
+			self.input_sample_rate = BARK_SAMPLE_RATE
+			self.output_sample_rate = args.output_sample_rate
+
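+			# preload all four Bark models up front; the "small" variants trade quality for VRAM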
+			preload_models(
+				text_use_gpu=True,
+				coarse_use_gpu=True,
+				fine_use_gpu=True,
+				codec_use_gpu=True,
+
+				text_use_small=small,
+				coarse_use_small=small,
+				fine_use_small=small,
+
+				force_reload=False
+			)
+
+		def create_voice( self, voice, device='cuda' ):
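+			# build a Bark speaker prompt (.npz) from a random whisper-transcribed segment of this voice's dataset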
+			transcription_json = f'./training/{voice}/whisper.json'
+			if not os.path.exists(transcription_json):
+				raise Exception(f"Transcription for voice not found: {voice}")
+			
+			transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
+			candidates = []
+			for file in transcriptions:
+				result = transcriptions[file]
+				for segment in result['segments']:
+					entry = (
+						file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
+						segment['end'] - segment['start'],
+						segment['text']
+					)
+					candidates.append(entry)
+
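+			# pick a random transcribed segment to serve as the reference clip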
+			candidates.sort(key=lambda x: x[1])
+			candidate = random.choice(candidates)
+			audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
+			text = candidate[-1]
+
+			print("Using as reference:", audio_filepath, text)
+
+			# Load and pre-process the audio waveform
+			model = load_codec_model(use_gpu=True)
+			wav, sr = torchaudio.load(audio_filepath)
+			wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+			wav = wav.unsqueeze(0).to(device)
+
+			# Extract discrete codes from EnCodec
+			with torch.no_grad():
+				encoded_frames = model.encode(wav)
+			codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy()  # [n_q, T]
+
+			# get seconds of audio
+			seconds = wav.shape[-1] / model.sample_rate
+			# generate semantic tokens for the reference text, capped to the clip's duration
+			semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
+
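+			# a Bark speaker prompt bundles the semantic tokens, the coarse codes (first two codebooks), and the full fine codes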
+			output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
+			np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
+
+		def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
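+			# create the speaker prompt on first use, then register it so bark accepts it as a history prompt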
+			if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'):
+				self.create_voice( voice )
+			voice = voice.replace("/", "_")
+			ALLOWED_PROMPTS.add( voice )
+
+			return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
+
 args = None
 tts = None
 tts_loading = False
@@ -96,6 +185,9 @@ training_state = None
 
 current_voice = None
 
+def cleanup_voice_name( name ):
+	return name.split("/")[-1]
+
 def resample( waveform, input_rate, output_rate=44100 ):
 	# mono-ize
 	waveform = torch.mean(waveform, dim=0, keepdim=True)
@@ -121,6 +213,291 @@ def generate(**kwargs):
 		return generate_tortoise(**kwargs)
 	if args.tts_backend == "vall-e":
 		return generate_valle(**kwargs)
+	if args.tts_backend == "bark":
+		return generate_bark(**kwargs)
+
+def generate_bark(**kwargs):
+	parameters = {}
+	parameters.update(kwargs)
+
+	voice = parameters['voice']
+	progress = parameters['progress'] if 'progress' in parameters else None
+	if parameters['seed'] == 0:
+		parameters['seed'] = None
+
+	usedSeed = parameters['seed']
+
+	global args
+	global tts
+
+	unload_whisper()
+	unload_voicefixer()
+
+	if not tts:
+		# should check if it's loading or unloaded, and load it if it's unloaded
+		if tts_loading:
+			raise Exception("TTS is still initializing...")
+		if progress is not None:
+			progress(0, "Initializing TTS...")
+		load_tts()
+	if hasattr(tts, "loading") and tts.loading:
+		raise Exception("TTS is still initializing...")
+
+	do_gc()
+
+	voice_samples = None
+	conditioning_latents = None
+	sample_voice = None
+
+	voice_cache = {}
+
+	def get_settings( override=None ):
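+		# Bark takes separate text (semantic) and waveform temperatures; both are driven by the single temperature slider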
+		settings = {
+			'voice': parameters['voice'],
+			'text_temp': float(parameters['temperature']),
+			'waveform_temp': float(parameters['temperature']),
+		}
+
+		# apply any per-line override on top of the defaults above
+		selected_voice = voice
+		if override is not None:
+			if 'voice' in override:
+				selected_voice = override['voice']
+
+			for k in override:
+				if k not in settings:
+					continue
+				settings[k] = override[k]
+
+		return settings
+
+	if not parameters['delimiter']:
+		parameters['delimiter'] = "\n"
+	elif parameters['delimiter'] == "\\n":
+		parameters['delimiter'] = "\n"
+
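+	# split on the delimiter when it appears in the text, otherwise chunk with split_and_recombine_text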
+	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
+		texts = parameters['text'].split(parameters['delimiter'])
+	else:
+		texts = split_and_recombine_text(parameters['text'])
+ 
+	full_start_time = time.time()
+ 
+	outdir = f"{args.results_folder}/{voice}/"
+	os.makedirs(outdir, exist_ok=True)
+
+	audio_cache = {}
+
+	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
+
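+	# scan existing outputs for this voice to find the next unused file index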
+	idx = 0
+	idx_cache = {}
+	for i, file in enumerate(os.listdir(outdir)):
+		filename = os.path.basename(file)
+		extension = os.path.splitext(filename)[1]
+		if extension != ".json" and extension != ".wav":
+			continue
+		match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
+		if match and len(match) > 0:
+			key = int(match[0])
+			idx_cache[key] = True
+
+	if len(idx_cache) > 0:
+		keys = sorted(list(idx_cache.keys()))
+		idx = keys[-1] + 1
+
+	idx = pad(idx, 4)
+
+	def get_name(line=0, candidate=0, combined=False):
+		name = f"{idx}"
+		if combined:
+			name = f"{name}_combined"
+		elif len(texts) > 1:
+			name = f"{name}_{line}"
+		if parameters['candidates'] > 1:
+			name = f"{name}_{candidate}"
+		return name
+
+	def get_info( voice, settings = None, latents = True ):
+		info = {}
+		info.update(parameters)
+
+		info['time'] = time.time()-full_start_time
+		info['datetime'] = datetime.now().isoformat()
+
+		info.pop('progress', None)
+
+		if info['delimiter'] == "\n":
+			info['delimiter'] = "\\n"
+
+		if settings is not None:
+			for k in settings:
+				if k in info:
+					info[k] = settings[k]
+		return info
+
+	INFERENCING = True
+	for line, cut_text in enumerate(texts):	
+		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
+		print(f"{progress.msg_prefix} Generating line: {cut_text}")
+		start_time = time.time()
+
+		# check for an inline JSON settings override at the start of the line
+		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text) 
+		override = None
+		if match and len(match) > 0:
+			match = match[0]
+			try:
+				override = json.loads(match[0])
+				cut_text = match[1].strip()
+			except Exception as e:
+				raise Exception("Prompt settings editing requested, but received invalid JSON")
+
+		settings = get_settings( override=override )
+
+		gen = tts.inference(cut_text, **settings )
+
+		run_time = time.time()-start_time
+		print(f"Generating line took {run_time} seconds")
+
+		if not isinstance(gen, list):
+			gen = [gen]
+
+		for j, g in enumerate(gen):
+			wav, sr = g
+			name = get_name(line=line, candidate=j)
+
+			settings['text'] = cut_text
+			settings['time'] = run_time
+			settings['datetime'] = datetime.now().isoformat()
+
+			# save here in case some error happens mid-batch
+			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
+			write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
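+			# reload through torchaudio so the resample/volume passes below operate on a tensor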
+			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
+
+			audio_cache[name] = {
+				'audio': wav,
+				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
+			}
+
+	del gen
+	do_gc()
+	INFERENCING = False
+
+	for k in audio_cache:
+		audio = audio_cache[k]['audio']
+
+		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
+		if volume_adjust is not None:
+			audio = volume_adjust(audio)
+
+		audio_cache[k]['audio'] = audio
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
+
+	output_voices = []
+	for candidate in range(parameters['candidates']):
+		if len(texts) > 1:
+			audio_clips = []
+			for line in range(len(texts)):
+				name = get_name(line=line, candidate=candidate)
+				audio = audio_cache[name]['audio']
+				audio_clips.append(audio)
+			
+			name = get_name(candidate=candidate, combined=True)
+			audio = torch.cat(audio_clips, dim=-1)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
+
+			audio = audio.squeeze(0).cpu()
+			audio_cache[name] = {
+				'audio': audio,
+				'settings': get_info(voice=voice),
+				'output': True
+			}
+		else:
+			name = get_name(candidate=candidate)
+			audio_cache[name]['output'] = True
+
+	if args.voice_fixer:
+		if not voicefixer:
+			progress(0, "Loading voicefix...")
+			load_voicefixer()
+
+		try:
+			fixed_cache = {}
+			for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
+				del audio_cache[name]['audio']
+				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+					continue
+
+				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
+				voicefixer.restore(
+					input=path,
+					output=fixed,
+					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
+					#mode=mode,
+				)
+				
+				fixed_cache[f'{name}_fixed'] = {
+					'settings': audio_cache[name]['settings'],
+					'output': True
+				}
+				audio_cache[name]['output'] = False
+			
+			for name in fixed_cache:
+				audio_cache[name] = fixed_cache[name]
+		except Exception as e:
+			print(e)
+			print("\nFailed to run Voicefixer")
+
+	for name in audio_cache:
+		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+			if args.prune_nonfinal_outputs:
+				audio_cache[name]['pruned'] = True
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
+			continue
+
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
+
+		if not args.embed_output_metadata:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
+				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
+
+	if args.embed_output_metadata:
+		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
+			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
+				continue
+
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
+			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
+			metadata.save()
+ 
+	if sample_voice is not None:
+		sample_voice = (tts.input_sample_rate, sample_voice.numpy())
+
+	info = get_info(voice=voice, latents=False)
+	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
+
+	info['seed'] = usedSeed
+	if 'latents' in info:
+		del info['latents']
+
+	os.makedirs('./config/', exist_ok=True)
+	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
+		f.write(json.dumps(info, indent='\t') )
+
+	stats = [
+		[ parameters['seed'], "{:.3f}".format(info['time']) ]
+	]
+
+	return (
+		sample_voice,
+		output_voices,
+		stats,
+	)
 
 def generate_valle(**kwargs):
 	parameters = {}
@@ -289,9 +666,9 @@ def generate_valle(**kwargs):
 			settings['datetime'] = datetime.now().isoformat()
 
 			# save here in case some error happens mid-batch
-			#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
-			soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
-			wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
+			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
+			soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
+			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 			audio_cache[name] = {
 				'audio': wav,
@@ -310,7 +687,7 @@ def generate_valle(**kwargs):
 			audio = volume_adjust(audio)
 
 		audio_cache[k]['audio'] = audio
-		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
 
 	output_voices = []
 	for candidate in range(parameters['candidates']):
@@ -323,7 +700,7 @@ def generate_valle(**kwargs):
 			
 			name = get_name(candidate=candidate, combined=True)
 			audio = torch.cat(audio_clips, dim=-1)
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
 
 			audio = audio.squeeze(0).cpu()
 			audio_cache[name] = {
@@ -348,8 +725,8 @@ def generate_valle(**kwargs):
 				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 					continue
 
-				path = f'{outdir}/{voice}_{name}.wav'
-				fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
 				voicefixer.restore(
 					input=path,
 					output=fixed,
@@ -373,13 +750,13 @@ def generate_valle(**kwargs):
 		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 			if args.prune_nonfinal_outputs:
 				audio_cache[name]['pruned'] = True
-				os.remove(f'{outdir}/{voice}_{name}.wav')
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 			continue
 
-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
 				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
 
 	if args.embed_output_metadata:
@@ -387,7 +764,7 @@ def generate_valle(**kwargs):
 			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
 				continue
 
-			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
 			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
 			metadata.save()
  
@@ -415,8 +792,6 @@ def generate_valle(**kwargs):
 		stats,
 	)
 
-
-
 def generate_tortoise(**kwargs):
 	parameters = {}
 	parameters.update(kwargs)
@@ -698,7 +1073,7 @@ def generate_tortoise(**kwargs):
 				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
 			}
 			# save here in case some error happens mid-batch
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)
 
 	del gen
 	do_gc()
@@ -712,7 +1087,7 @@ def generate_tortoise(**kwargs):
 			audio = volume_adjust(audio)
 
 		audio_cache[k]['audio'] = audio
-		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
 
 	output_voices = []
 	for candidate in range(parameters['candidates']):
@@ -725,7 +1100,7 @@ def generate_tortoise(**kwargs):
 			
 			name = get_name(candidate=candidate, combined=True)
 			audio = torch.cat(audio_clips, dim=-1)
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
 
 			audio = audio.squeeze(0).cpu()
 			audio_cache[name] = {
@@ -750,8 +1125,8 @@ def generate_tortoise(**kwargs):
 				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 					continue
 
-				path = f'{outdir}/{voice}_{name}.wav'
-				fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
 				voicefixer.restore(
 					input=path,
 					output=fixed,
@@ -775,13 +1150,13 @@ def generate_tortoise(**kwargs):
 		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 			if args.prune_nonfinal_outputs:
 				audio_cache[name]['pruned'] = True
-				os.remove(f'{outdir}/{voice}_{name}.wav')
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 			continue
 
-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
 				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
 
 	if args.embed_output_metadata:
@@ -789,7 +1164,7 @@ def generate_tortoise(**kwargs):
 			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
 				continue
 
-			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
 			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
 			metadata.save()
  
@@ -1096,9 +1471,9 @@ class TrainingState():
 					'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
 					'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
 
-				#	'ar.loss.nll', 'nar.loss.nll',
-				#	'ar-half.loss.nll', 'nar-half.loss.nll',
-				#	'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
+					'ar.loss.nll', 'nar.loss.nll',
+					'ar-half.loss.nll', 'nar-half.loss.nll',
+					'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
 				]
 
 				keys['accuracies'] = [
@@ -1464,14 +1839,14 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
 		return_code = training_state.process.wait()
 		training_state = None
 
-def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
+def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
 	global training_state
 	losses = None
 	lrs = None
 	grad_norms = None
 
-	x_lim = [ 0, x_lim ]
-	y_lim = [ 0, y_lim ]
+	x_lim = [ x_min, x_max ]
+	y_lim = [ y_min, y_max ]
 
 	if not training_state:
 		if config_path:
@@ -1490,23 +1865,23 @@ def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
 			losses = gr.LinePlot.update(
 				value = pd.DataFrame(training_state.statistics['loss']),
 				x_lim=x_lim, y_lim=y_lim,
-				x="epoch", y="value",
+				x="it", y="value", # x="epoch",
 				title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
 				width=500, height=350
 			)
 		if len(training_state.statistics['lr']) > 0:
 			lrs = gr.LinePlot.update(
 				value = pd.DataFrame(training_state.statistics['lr']),
-				x_lim=x_lim, y_lim=y_lim,
-				x="epoch", y="value",
+				x_lim=x_lim,
+				x="it", y="value", # x="epoch",
 				title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
 				width=500, height=350
 			)
 		if len(training_state.statistics['grad_norm']) > 0:
 			grad_norms = gr.LinePlot.update(
 				value = pd.DataFrame(training_state.statistics['grad_norm']),
-				x_lim=x_lim, y_lim=y_lim,
-				x="epoch", y="value",
+				x_lim=x_lim,
+				x="it", y="value", # x="epoch",
 				title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
 				width=500, height=350
 			)
@@ -1649,13 +2024,13 @@ def whisper_transcribe( file, language=None ):
 		device = "cuda" if get_device_name() == "cuda" else "cpu"
 		if whisper_vad:
 			# omits a considerable amount of the end
-			"""
 			if args.whisper_batchsize > 1:
 				result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
 			else:
 				result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
 			"""
 			result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
+			"""
 		else:
 			result = whisper_model.transcribe(file)
 			
@@ -1717,7 +2092,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 	os.makedirs(f'{indir}/audio/', exist_ok=True)
 	
 	TARGET_SAMPLE_RATE = 22050
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
 		TARGET_SAMPLE_RATE = 24000
 	if tts:
 		TARGET_SAMPLE_RATE = tts.input_sample_rate
@@ -1735,7 +2110,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 		try:
 			result = whisper_transcribe(file, language=language)
 		except Exception as e:
-			print("Failed to transcribe:", file)
+			print("Failed to transcribe:", file, e)
 			continue
 
 		results[basename] = result
@@ -1802,7 +2177,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 		results = json.load(open(infile, 'r', encoding="utf-8"))
 
 	TARGET_SAMPLE_RATE = 22050
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
 		TARGET_SAMPLE_RATE = 24000
 	if tts:
 		TARGET_SAMPLE_RATE = tts.input_sample_rate
@@ -1934,8 +2309,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 	lines = { 'training': [], 'validation': [] }
 	segments = {}
 
-	# I'm not sure how the VALL-E implementation decides what's validation and what's not
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
 		text_length = 0
 		audio_length = 0
 
@@ -3008,6 +3382,10 @@ def load_tts( restart=False,
 
 		print(f"Loading VALL-E... (Config: {valle_model})")
 		tts = VALLE_TTS(config=args.valle_model)
+	elif args.tts_backend == "bark":
+
+		print(f"Loading Bark...")
+		tts = Bark_TTS(small=args.low_vram)
 
 	print("Loaded TTS, ready for generation.")
 	tts_loading = False
diff --git a/src/webui.py b/src/webui.py
index 66540c1..6bc4bbd 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -167,6 +167,10 @@ def reset_generate_settings_proxy():
 	return tuple(res)
 
 def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
+	if args.tts_backend == "bark":
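+		# bark has no conditioning latents; (re)build the voice's .npz speaker prompt instead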
+		global tts
+		tts.create_voice( voice )
+		return voice
 	compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
 	return voice
 
@@ -222,13 +226,13 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len
 		print("Processing:", voice)
 		message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
 		messages.append(message)
-	"""
 
 	if slice_audio:
 		for voice in voices:
 			print("Processing:", voice)
 			message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
 			messages.append(message)
+	"""
 
 	for voice in voices:
 		print("Processing:", voice)
@@ -400,12 +404,13 @@ def setup_gradio():
 						outputs=GENERATE_SETTINGS["mic_audio"],
 					)
 				with gr.Column():
+					preset = None
 					GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
-					GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
+					GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed", visible=args.tts_backend!="tortoise")
 
-					preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
+					preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast", visible=args.tts_backend=="tortoise" )
 
-					GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
+					GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples", visible=args.tts_backend!="bark")
 					GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
 
 					GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
@@ -490,7 +495,7 @@ def setup_gradio():
 						merger_button = gr.Button(value="Run Merger")
 				with gr.Column():
 					merger_output = gr.TextArea(label="Console Output", max_lines=8)
-		with gr.Tab("Training"):
+		with gr.Tab("Training", visible=args.tts_backend != "bark"):
 			with gr.Tab("Prepare Dataset"):
 				with gr.Row():
 					with gr.Column():
@@ -586,8 +591,10 @@ def setup_gradio():
 						keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
 						
 						with gr.Row():
-							training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
-							training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
+							training_graph_x_min = gr.Number(label="X Min", precision=0, value=0)
+							training_graph_x_max = gr.Number(label="X Max", precision=0, value=0)
+							training_graph_y_min = gr.Number(label="Y Min", precision=0, value=0)
+							training_graph_y_max = gr.Number(label="Y Max", precision=0, value=0)
 
 						with gr.Row():
 							start_training_button = gr.Button(value="Train")
@@ -597,7 +604,7 @@ def setup_gradio():
 						
 					with gr.Column():
 						training_loss_graph = gr.LinePlot(label="Training Metrics",
-							x="epoch",
+							x="it", # x="epoch",
 							y="value",
 							title="Loss Metrics",
 							color="type",
@@ -606,7 +613,7 @@ def setup_gradio():
 							height=350,
 						)
 						training_lr_graph = gr.LinePlot(label="Training Metrics",
-							x="epoch",
+							x="it", # x="epoch",
 							y="value",
 							title="Learning Rate",
 							color="type",
@@ -615,7 +622,7 @@ def setup_gradio():
 							height=350,
 						)
 						training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
-							x="epoch",
+							x="it", # x="epoch",
 							y="value",
 							title="Gradient Normals",
 							color="type",
@@ -765,13 +772,14 @@ def setup_gradio():
 			inputs=show_experimental_settings,
 			outputs=experimental_column
 		)
-		preset.change(fn=update_presets,
-			inputs=preset,
-			outputs=[
-				GENERATE_SETTINGS['num_autoregressive_samples'],
-				GENERATE_SETTINGS['diffusion_iterations'],
-			],
-		)
+		if preset:
+			preset.change(fn=update_presets,
+				inputs=preset,
+				outputs=[
+					GENERATE_SETTINGS['num_autoregressive_samples'],
+					GENERATE_SETTINGS['diffusion_iterations'],
+				],
+			)
 
 		recompute_voice_latents.click(compute_latents_proxy,
 			inputs=[
@@ -860,8 +868,10 @@ def setup_gradio():
 		training_output.change(
 			fn=update_training_dataplot,
 			inputs=[
-				training_graph_x_lim,
-				training_graph_y_lim,
+				training_graph_x_min,
+				training_graph_x_max,
+				training_graph_y_min,
+				training_graph_y_max,
 			],
 			outputs=[
 				training_loss_graph,
@@ -874,8 +884,10 @@ def setup_gradio():
 		view_losses.click(
 			fn=update_training_dataplot,
 			inputs=[
-				training_graph_x_lim,
-				training_graph_y_lim,
+				training_graph_x_min,
+				training_graph_x_max,
+				training_graph_y_min,
+				training_graph_y_max,
 				training_configs,
 			],
 			outputs=[