From b6440091fbfbf44deffafa82ec284dad274a8e9e Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 26 Apr 2023 04:48:09 +0000
Subject: [PATCH] Very, very, VERY, barebones integration with Bark (documentation soon)

---
 src/utils.py | 456 ++++++++++++++++++++++++++++++++++++++++++++++-----
 src/webui.py |  54 +++---
 2 files changed, 450 insertions(+), 60 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 9b07d93..3cf50fe 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -31,6 +31,7 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
+import numpy as np
 
 from glob import glob
 from datetime import datetime
@@ -65,6 +66,7 @@ MIN_TRAINING_DURATION = 0.6
 MAX_TRAINING_DURATION = 11.6097505669
 
 VALLE_ENABLED = False
+BARK_ENABLED = False
 
 try:
 	from vall_e.emb.qnt import encode as valle_quantize
@@ -76,11 +78,98 @@ try:
 
 	VALLE_ENABLED = True
 except Exception as e:
+	if False: # args.tts_backend == "vall-e":
+		raise e
 	pass
 
 if VALLE_ENABLED:
 	TTSES.append('vall-e')
 
+try:
+	from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
+	from bark.api import generate_audio as bark_generate_audio
+	from encodec.utils import convert_audio
+
+	from scipy.io.wavfile import write as write_wav
+
+	BARK_ENABLED = True
+except Exception as e:
+	if False: # args.tts_backend == "bark":
+		raise e
+	pass
+
+if BARK_ENABLED:
+	TTSES.append('bark')
+
+	class Bark_TTS():
+		def __init__(self, small=False):
+			self.input_sample_rate = BARK_SAMPLE_RATE
+			self.output_sample_rate = args.output_sample_rate
+
+			preload_models(
+				text_use_gpu=True,
+				coarse_use_gpu=True,
+				fine_use_gpu=True,
+				codec_use_gpu=True,
+
+				text_use_small=small,
+				coarse_use_small=small,
+				fine_use_small=small,
+
+				force_reload=False
+			)
+
+		def create_voice( self, voice, device='cuda' ):
+			transcription_json = f'./training/{voice}/whisper.json'
+			if not os.path.exists(transcription_json):
+				raise Exception(f"Transcription for voice not found: {voice}")
+
+			transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
+			candidates = []
+			for file in transcriptions:
+				result = transcriptions[file]
+				for segment in result['segments']:
+					entry = (
+						file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
+						segment['end'] - segment['start'],
+						segment['text']
+					)
+					candidates.append(entry)
+
+			candidates.sort(key=lambda x: x[1])
+			candidate = random.choice(candidates)
+			audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
+			text = candidate[-1]
+
+			print("Using as reference:", audio_filepath, text)
+
+			# Load and pre-process the audio waveform
+			model = load_codec_model(use_gpu=True)
+			wav, sr = torchaudio.load(audio_filepath)
+			wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+			wav = wav.unsqueeze(0).to(device)
+
+			# Extract discrete codes from EnCodec
+			with torch.no_grad():
+				encoded_frames = model.encode(wav)
+			codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]
+
+			# get seconds of audio
+			seconds = wav.shape[-1] / model.sample_rate
+			# generate semantic tokens
+			semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
+
+			output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
+			np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
+
+		def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
+			if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
+				self.create_voice( voice )
+			voice = voice.replace("/", "_")
+			if voice not in ALLOWED_PROMPTS:
+				ALLOWED_PROMPTS.add( voice )
+
+			return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
+
 args = None
 tts = None
 tts_loading = False
@@ -96,6 +185,9 @@ training_state = None
 
 current_voice = None
 
+def cleanup_voice_name( name ):
+	return name.split("/")[-1]
+
 def resample( waveform, input_rate, output_rate=44100 ):
 	# mono-ize
 	waveform = torch.mean(waveform, dim=0, keepdim=True)
@@ -121,6 +213,291 @@ def generate(**kwargs):
 		return generate_tortoise(**kwargs)
 	if args.tts_backend == "vall-e":
 		return generate_valle(**kwargs)
+	if args.tts_backend == "bark":
+		return generate_bark(**kwargs)
+
+def generate_bark(**kwargs):
+	parameters = {}
+	parameters.update(kwargs)
+
+	voice = parameters['voice']
+	progress = parameters['progress'] if 'progress' in parameters else None
+	if parameters['seed'] == 0:
+		parameters['seed'] = None
+
+	usedSeed = parameters['seed']
+
+	global args
+	global tts
+
+	unload_whisper()
+	unload_voicefixer()
+
+	if not tts:
+		# should check if it's loading or unloaded, and load it if it's unloaded
+		if tts_loading:
+			raise Exception("TTS is still initializing...")
+		if progress is not None:
+			progress(0, "Initializing TTS...")
+		load_tts()
+	if hasattr(tts, "loading") and tts.loading:
+		raise Exception("TTS is still initializing...")
+
+	do_gc()
+
+	voice_samples = None
+	conditioning_latents = None
+	sample_voice = None
+
+	voice_cache = {}
+
+	def get_settings( override=None ):
+		settings = {
+			'voice': parameters['voice'],
+			'text_temp': float(parameters['temperature']),
+			'waveform_temp': float(parameters['temperature']),
+		}
+
+		# could be better to just do a ternary on everything above, but i am not a professional
+		selected_voice = voice
+		if override is not None:
+			if 'voice' in override:
+				selected_voice = override['voice']
+
+			for k in override:
+				if k not in settings:
+					continue
+				settings[k] = override[k]
+
+		return settings
+
+	if not parameters['delimiter']:
+		parameters['delimiter'] = "\n"
+	elif parameters['delimiter'] == "\\n":
+		parameters['delimiter'] = "\n"
+
+	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
+		texts = parameters['text'].split(parameters['delimiter'])
+	else:
+		texts = split_and_recombine_text(parameters['text'])
+
+	full_start_time = time.time()
+
+	outdir = f"{args.results_folder}/{voice}/"
+	os.makedirs(outdir, exist_ok=True)
+
+	audio_cache = {}
+
+	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
+
+	idx = 0
+	idx_cache = {}
+	for i, file in enumerate(os.listdir(outdir)):
+		filename = os.path.basename(file)
+		extension = os.path.splitext(filename)[1]
+		if extension != ".json" and extension != ".wav":
+			continue
+		match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
+		if match and len(match) > 0:
+			key = int(match[0])
+			idx_cache[key] = True
+
+	if len(idx_cache) > 0:
+		keys = sorted(list(idx_cache.keys()))
+		idx = keys[-1] + 1
+
+	idx = pad(idx, 4)
+
+	def get_name(line=0, candidate=0, combined=False):
+		name = f"{idx}"
+		if combined:
+			name = f"{name}_combined"
+		elif len(texts) > 1:
+			name = f"{name}_{line}"
+		if parameters['candidates'] > 1:
+			name = f"{name}_{candidate}"
+		return name
+
+	def get_info( voice, settings = None, latents = True ):
+		info = {}
+		info.update(parameters)
+
+		info['time'] = time.time()-full_start_time
+		info['datetime'] = datetime.now().isoformat()
+
+		info['progress'] = None
+		del info['progress']
+
+		if info['delimiter'] == "\n":
+			info['delimiter'] = "\\n"
+
+		if settings is not None:
+			for k in settings:
+				if k in info:
+					info[k] = settings[k]
+
+		return info
+
+	INFERENCING = True
+	for line, cut_text in enumerate(texts):
+		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
+		print(f"{progress.msg_prefix} Generating line: {cut_text}")
+		start_time = time.time()
+
+		# do setting editing
+		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
+		override = None
+		if match and len(match) > 0:
+			match = match[0]
+			try:
+				override = json.loads(match[0])
+				cut_text = match[1].strip()
+			except Exception as e:
+				raise Exception("Prompt settings editing requested, but received invalid JSON")
+
+		settings = get_settings( override=override )
+
+		gen = tts.inference(cut_text, **settings )
+
+		run_time = time.time()-start_time
+		print(f"Generating line took {run_time} seconds")
+
+		if not isinstance(gen, list):
+			gen = [gen]
+
+		for j, g in enumerate(gen):
+			wav, sr = g
+			name = get_name(line=line, candidate=j)
+
+			settings['text'] = cut_text
+			settings['time'] = run_time
+			settings['datetime'] = datetime.now().isoformat()
+
+			# save here in case some error happens mid-batch
+			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
+			write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
+			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
+
+			audio_cache[name] = {
+				'audio': wav,
+				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
+			}
+
+		del gen
+		do_gc()
+	INFERENCING = False
+
+	for k in audio_cache:
+		audio = audio_cache[k]['audio']
+
+		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
+		if volume_adjust is not None:
+			audio = volume_adjust(audio)
+
+		audio_cache[k]['audio'] = audio
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
+
+	output_voices = []
+	for candidate in range(parameters['candidates']):
+		if len(texts) > 1:
+			audio_clips = []
+			for line in range(len(texts)):
+				name = get_name(line=line, candidate=candidate)
+				audio = audio_cache[name]['audio']
+				audio_clips.append(audio)
+
+			name = get_name(candidate=candidate, combined=True)
+			audio = torch.cat(audio_clips, dim=-1)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
+
+			audio = audio.squeeze(0).cpu()
+			audio_cache[name] = {
+				'audio': audio,
+				'settings': get_info(voice=voice),
+				'output': True
+			}
+		else:
+			name = get_name(candidate=candidate)
+			audio_cache[name]['output'] = True
+
+	if args.voice_fixer:
+		if not voicefixer:
+			progress(0, "Loading voicefix...")
+			load_voicefixer()
+
+		try:
+			fixed_cache = {}
+			for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
+				del audio_cache[name]['audio']
+				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+					continue
+
+				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
+				voicefixer.restore(
+					input=path,
+					output=fixed,
+					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
+					#mode=mode,
+				)
+
+				fixed_cache[f'{name}_fixed'] = {
+					'settings': audio_cache[name]['settings'],
+					'output': True
+				}
+				audio_cache[name]['output'] = False
+
+			for name in fixed_cache:
+				audio_cache[name] = fixed_cache[name]
+		except Exception as e:
+			print(e)
+			print("\nFailed to run Voicefixer")
+
+	for name in audio_cache:
+		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+			if args.prune_nonfinal_outputs:
+				audio_cache[name]['pruned'] = True
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
+			continue
+
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
+
+		if not args.embed_output_metadata:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
+				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
+
+	if args.embed_output_metadata:
+		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
+			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
+				continue
+
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
+			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
+			metadata.save()
+
+	if sample_voice is not None:
+		sample_voice = (tts.input_sample_rate, sample_voice.numpy())
+
+	info = get_info(voice=voice, latents=False)
+	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
+
+	info['seed'] = usedSeed
+	if 'latents' in info:
+		del info['latents']
+
+	os.makedirs('./config/', exist_ok=True)
+	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
+		f.write(json.dumps(info, indent='\t') )
+
+	stats = [
+		[ parameters['seed'], "{:.3f}".format(info['time']) ]
+	]
+
+	return (
+		sample_voice,
+		output_voices,
+		stats,
+	)
 
 def generate_valle(**kwargs):
 	parameters = {}
@@ -289,9 +666,9 @@ def generate_valle(**kwargs):
 			settings['datetime'] = datetime.now().isoformat()
 
 			# save here in case some error happens mid-batch
-			#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
-			soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
-			wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
+			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
+			soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
+			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 			audio_cache[name] = {
 				'audio': wav,
@@ -310,7 +687,7 @@ def generate_valle(**kwargs):
 			audio = volume_adjust(audio)
 
 		audio_cache[k]['audio'] = audio
-		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
 
 	output_voices = []
 	for candidate in range(parameters['candidates']):
@@ -323,7 +700,7 @@ def generate_valle(**kwargs):
 
 			name = get_name(candidate=candidate, combined=True)
 			audio = torch.cat(audio_clips, dim=-1)
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
 
 			audio = audio.squeeze(0).cpu()
 			audio_cache[name] = {
@@ -348,8 +725,8 @@ def generate_valle(**kwargs):
 			if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 				continue
 
-			path = f'{outdir}/{voice}_{name}.wav'
-			fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+			path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+			fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
 			voicefixer.restore(
 				input=path,
 				output=fixed,
@@ -373,13 +750,13 @@ def generate_valle(**kwargs):
 		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 			if args.prune_nonfinal_outputs:
 				audio_cache[name]['pruned'] = True
-				os.remove(f'{outdir}/{voice}_{name}.wav')
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 			continue
 
-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
 				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
 
 	if args.embed_output_metadata:
@@ -387,7 +764,7 @@ def generate_valle(**kwargs):
 			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
 				continue
 
-			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
 			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
 			metadata.save()
 
@@ -415,8 +792,6 @@ def generate_valle(**kwargs):
 		stats,
 	)
 
-
-
 def generate_tortoise(**kwargs):
 	parameters = {}
 	parameters.update(kwargs)
@@ -698,7 +1073,7 @@ def generate_tortoise(**kwargs):
 				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
 			}
 			# save here in case some error happens mid-batch
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)
 
 		del gen
 		do_gc()
@@ -712,7 +1087,7 @@ def generate_tortoise(**kwargs):
 			audio = volume_adjust(audio)
 
 		audio_cache[k]['audio'] = audio
-		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
 
 	output_voices = []
 	for candidate in range(parameters['candidates']):
@@ -725,7 +1100,7 @@ def generate_tortoise(**kwargs):
 
 			name = get_name(candidate=candidate, combined=True)
 			audio = torch.cat(audio_clips, dim=-1)
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
 
 			audio = audio.squeeze(0).cpu()
 			audio_cache[name] = {
@@ -750,8 +1125,8 @@ def generate_tortoise(**kwargs):
 			if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 				continue
 
-			path = f'{outdir}/{voice}_{name}.wav'
-			fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+			path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+			fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
 			voicefixer.restore(
 				input=path,
 				output=fixed,
@@ -775,13 +1150,13 @@ def generate_tortoise(**kwargs):
 		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
 			if args.prune_nonfinal_outputs:
 				audio_cache[name]['pruned'] = True
-				os.remove(f'{outdir}/{voice}_{name}.wav')
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 			continue
 
-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
 
 		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
 				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
 
 	if args.embed_output_metadata:
@@ -789,7 +1164,7 @@ def generate_tortoise(**kwargs):
 			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
 				continue
 
-			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
 			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
 			metadata.save()
 
@@ -1096,9 +1471,9 @@ class TrainingState():
 				'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
 				'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
 
-				# 'ar.loss.nll', 'nar.loss.nll',
-				# 'ar-half.loss.nll', 'nar-half.loss.nll',
-				# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
+				'ar.loss.nll', 'nar.loss.nll',
+				'ar-half.loss.nll', 'nar-half.loss.nll',
+				'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
 			]
 
 			keys['accuracies'] = [
@@ -1464,14 +1839,14 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
 	return_code = training_state.process.wait()
 	training_state = None
 
-def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
+def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
 	global training_state
 	losses = None
 	lrs = None
 	grad_norms = None
 
-	x_lim = [ 0, x_lim ]
-	y_lim = [ 0, y_lim ]
+	x_lim = [ x_min, x_max ]
+	y_lim = [ y_min, y_max ]
 
 	if not training_state:
 		if config_path:
@@ -1490,23 +1865,23 @@ def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
 			losses = gr.LinePlot.update(
 				value = pd.DataFrame(training_state.statistics['loss']),
 				x_lim=x_lim, y_lim=y_lim,
-				x="epoch", y="value",
+				x="it", y="value", # x="epoch",
 				title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
 				width=500, height=350
 			)
 		if len(training_state.statistics['lr']) > 0:
 			lrs = gr.LinePlot.update(
 				value = pd.DataFrame(training_state.statistics['lr']),
-				x_lim=x_lim, y_lim=y_lim,
-				x="epoch", y="value",
+				x_lim=x_lim,
+				x="it", y="value", # x="epoch",
 				title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
 				width=500, height=350
 			)
 		if len(training_state.statistics['grad_norm']) > 0:
 			grad_norms = gr.LinePlot.update(
 				value = pd.DataFrame(training_state.statistics['grad_norm']),
-				x_lim=x_lim, y_lim=y_lim,
-				x="epoch", y="value",
+				x_lim=x_lim,
+				x="it", y="value", # x="epoch",
 				title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
 				width=500, height=350
 			)
@@ -1649,13 +2024,13 @@ def whisper_transcribe( file, language=None ):
 		device = "cuda" if get_device_name() == "cuda" else "cpu"
 
 		if whisper_vad:
 			# omits a considerable amount of the end
-			"""
 			if args.whisper_batchsize > 1:
 				result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
 			else:
 				result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
 			"""
 			result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
+			"""
 		else:
 			result = whisper_model.transcribe(file)
@@ -1717,7 +2092,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 	os.makedirs(f'{indir}/audio/', exist_ok=True)
 
 	TARGET_SAMPLE_RATE = 22050
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
 		TARGET_SAMPLE_RATE = 24000
 	if tts:
 		TARGET_SAMPLE_RATE = tts.input_sample_rate
@@ -1735,7 +2110,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 		try:
 			result = whisper_transcribe(file, language=language)
 		except Exception as e:
-			print("Failed to transcribe:", file)
+			print("Failed to transcribe:", file, e)
 			continue
 
 		results[basename] = result
@@ -1802,7 +2177,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 		results = json.load(open(infile, 'r', encoding="utf-8"))
 
 	TARGET_SAMPLE_RATE = 22050
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
 		TARGET_SAMPLE_RATE = 24000
 	if tts:
 		TARGET_SAMPLE_RATE = tts.input_sample_rate
@@ -1934,8 +2309,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 	lines = { 'training': [], 'validation': [] }
 	segments = {}
 
-	# I'm not sure how the VALL-E implementation decides what's validation and what's not
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
 		text_length = 0
 		audio_length = 0
 
@@ -3008,6 +3382,10 @@ def load_tts( restart=False,
 		print(f"Loading VALL-E... (Config: {valle_model})")
 		tts = VALLE_TTS(config=args.valle_model)
+	elif args.tts_backend == "bark":
+
+		print(f"Loading Bark...")
+		tts = Bark_TTS(small=args.low_vram)
 
 	print("Loaded TTS, ready for generation.")
 	tts_loading = False
diff --git a/src/webui.py b/src/webui.py
index 66540c1..6bc4bbd 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -167,6 +167,10 @@ def reset_generate_settings_proxy():
 	return tuple(res)
 
 def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
+	if args.tts_backend == "bark":
+		global tts
+		tts.create_voice( voice )
+		return voice
 	compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
 	return voice
 
@@ -222,13 +226,13 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len
 			print("Processing:", voice)
 			message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
 			messages.append(message)
 
-	"""
 	if slice_audio:
 		for voice in voices:
 			print("Processing:", voice)
 			message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
 			messages.append(message)
+	"""
 
 	for voice in voices:
 		print("Processing:", voice)
@@ -400,12 +404,13 @@ def setup_gradio():
 						outputs=GENERATE_SETTINGS["mic_audio"],
 					)
 				with gr.Column():
+					preset = None
 					GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
-					GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
+					GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed", visible=args.tts_backend!="tortoise")
 
-					preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
+					preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast", visible=args.tts_backend=="tortoise" )
 
-					GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
+					GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples", visible=args.tts_backend!="bark")
 					GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
 
 					GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
@@ -490,7 +495,7 @@ def setup_gradio():
 							merger_button = gr.Button(value="Run Merger")
 					with gr.Column():
 						merger_output = gr.TextArea(label="Console Output", max_lines=8)
-		with gr.Tab("Training"):
"bark"): with gr.Tab("Prepare Dataset"): with gr.Row(): with gr.Column(): @@ -586,8 +591,10 @@ def setup_gradio(): keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) with gr.Row(): - training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0) - training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0) + training_graph_x_min = gr.Number(label="X Min", precision=0, value=0) + training_graph_x_max = gr.Number(label="X Max", precision=0, value=0) + training_graph_y_min = gr.Number(label="Y Min", precision=0, value=0) + training_graph_y_max = gr.Number(label="Y Max", precision=0, value=0) with gr.Row(): start_training_button = gr.Button(value="Train") @@ -597,7 +604,7 @@ def setup_gradio(): with gr.Column(): training_loss_graph = gr.LinePlot(label="Training Metrics", - x="epoch", + x="it", # x="epoch", y="value", title="Loss Metrics", color="type", @@ -606,7 +613,7 @@ def setup_gradio(): height=350, ) training_lr_graph = gr.LinePlot(label="Training Metrics", - x="epoch", + x="it", # x="epoch", y="value", title="Learning Rate", color="type", @@ -615,7 +622,7 @@ def setup_gradio(): height=350, ) training_grad_norm_graph = gr.LinePlot(label="Training Metrics", - x="epoch", + x="it", # x="epoch", y="value", title="Gradient Normals", color="type", @@ -765,13 +772,14 @@ def setup_gradio(): inputs=show_experimental_settings, outputs=experimental_column ) - preset.change(fn=update_presets, - inputs=preset, - outputs=[ - GENERATE_SETTINGS['num_autoregressive_samples'], - GENERATE_SETTINGS['diffusion_iterations'], - ], - ) + if preset: + preset.change(fn=update_presets, + inputs=preset, + outputs=[ + GENERATE_SETTINGS['num_autoregressive_samples'], + GENERATE_SETTINGS['diffusion_iterations'], + ], + ) recompute_voice_latents.click(compute_latents_proxy, inputs=[ @@ -860,8 +868,10 @@ def setup_gradio(): training_output.change( fn=update_training_dataplot, inputs=[ - training_graph_x_lim, - training_graph_y_lim, + training_graph_x_min, + training_graph_x_max, + training_graph_y_min, + training_graph_y_max, ], outputs=[ training_loss_graph, @@ -874,8 +884,10 @@ def setup_gradio(): view_losses.click( fn=update_training_dataplot, inputs=[ - training_graph_x_lim, - training_graph_y_lim, + training_graph_x_min, + training_graph_x_max, + training_graph_y_min, + training_graph_y_max, training_configs, ], outputs=[