import os if 'XDG_CACHE_HOME' not in os.environ: os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/')) if 'TORTOISE_MODELS_DIR' not in os.environ: os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/')) if 'TRANSFORMERS_CACHE' not in os.environ: os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/')) import argparse import time import math import json import base64 import re import urllib.request import signal import gc import subprocess import psutil import yaml import hashlib import string from tqdm import tqdm import torch import torchaudio import music_tag import gradio as gr import gradio.utils import pandas as pd from datetime import datetime from datetime import timedelta from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices from tortoise.utils.text import split_and_recombine_text from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc from whisper.normalizers.english import EnglishTextNormalizer from whisper.normalizers.basic import BasicTextNormalizer from whisper.tokenizer import LANGUAGES MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] TTSES = ['tortoise'] INFERENCING = False GENERATE_SETTINGS_ARGS = None LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"} LEARNING_RATE_SCHEDULE = [ 2, 4, 9, 18, 25, 33, 50 ] RESAMPLERS = {} MIN_TRAINING_DURATION = 0.6 MAX_TRAINING_DURATION = 11.6097505669 VALLE_ENABLED = False try: from vall_e.emb.qnt import encode as valle_quantize VALLE_ENABLED = True except Exception as e: pass if VALLE_ENABLED: TTSES.append('vall-e') args = None tts = None tts_loading = False webui = None voicefixer = None whisper_model = None whisper_vad = None whisper_diarize = None whisper_align_model = None training_state = None current_voice = None def resample( waveform, input_rate, output_rate=44100 ): # mono-ize waveform = torch.mean(waveform, dim=0, keepdim=True) if input_rate == output_rate: return waveform, output_rate key = f'{input_rate}:{output_rate}' if not key in RESAMPLERS: RESAMPLERS[key] = torchaudio.transforms.Resample( input_rate, output_rate, lowpass_filter_width=16, rolloff=0.85, resampling_method="kaiser_window", beta=8.555504641634386, ) return RESAMPLERS[key]( waveform ), output_rate def generate(**kwargs): parameters = {} parameters.update(kwargs) voice = parameters['voice'] progress = parameters['progress'] if 'progress' in parameters else None if parameters['seed'] == 0: parameters['seed'] = None usedSeed = parameters['seed'] global args global tts unload_whisper() unload_voicefixer() if not tts: # should check if it's loading or unloaded, and load it if it's unloaded if tts_loading: raise Exception("TTS is still initializing...") if progress is not None: progress(0, "Initializing TTS...") load_tts() if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") do_gc() voice_samples = None conditioning_latents =None sample_voice = None voice_cache = {} def fetch_voice( voice ): cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}' if cache_key in voice_cache: return voice_cache[cache_key] print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}") sample_voice = None if voice == "microphone": if parameters['mic_audio'] is None: raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None elif voice == "random": voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() else: if progress is not None: progress(0, desc=f"Loading voice: {voice}") voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) if voice_samples and len(voice_samples) > 0: if conditioning_latents is None: conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks']) sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() voice_samples = None voice_cache[cache_key] = (voice_samples, conditioning_latents, sample_voice) return voice_cache[cache_key] def get_settings( override=None ): settings = { 'temperature': float(parameters['temperature']), 'top_p': float(parameters['top_p']), 'diffusion_temperature': float(parameters['diffusion_temperature']), 'length_penalty': float(parameters['length_penalty']), 'repetition_penalty': float(parameters['repetition_penalty']), 'cond_free_k': float(parameters['cond_free_k']), 'num_autoregressive_samples': parameters['num_autoregressive_samples'], 'sample_batch_size': args.sample_batch_size, 'diffusion_iterations': parameters['diffusion_iterations'], 'voice_samples': None, 'conditioning_latents': None, 'use_deterministic_seed': parameters['seed'], 'return_deterministic_state': True, 'k': parameters['candidates'], 'diffusion_sampler': parameters['diffusion_sampler'], 'breathing_room': parameters['breathing_room'], 'progress': parameters['progress'], 'half_p': "Half Precision" in parameters['experimentals'], 'cond_free': "Conditioning-Free" in parameters['experimentals'], 'cvvp_amount': parameters['cvvp_weight'], 'autoregressive_model': args.autoregressive_model, 'diffusion_model': args.diffusion_model, 'tokenizer_json': args.tokenizer_json, } # could be better to just do a ternary on everything above, but i am not a professional selected_voice = voice if override is not None: if 'voice' in override: selected_voice = override['voice'] for k in override: if k not in settings: continue settings[k] = override[k] if settings['autoregressive_model'] is not None: if settings['autoregressive_model'] == "auto": settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice) tts.load_autoregressive_model(settings['autoregressive_model']) if settings['diffusion_model'] is not None: if settings['diffusion_model'] == "auto": settings['diffusion_model'] = deduce_diffusion_model(selected_voice) tts.load_diffusion_model(settings['diffusion_model']) if settings['tokenizer_json'] is not None: tts.load_tokenizer_json(settings['tokenizer_json']) settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(voice=selected_voice) # clamp it down for the insane users who want this # it would be wiser to enforce the sample size to the batch size, but this is what the user wants settings['sample_batch_size'] = args.sample_batch_size if not settings['sample_batch_size']: settings['sample_batch_size'] = tts.autoregressive_batch_size if settings['num_autoregressive_samples'] < settings['sample_batch_size']: settings['sample_batch_size'] = settings['num_autoregressive_samples'] if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0: print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.") settings['cvvp_amount'] = 0 return settings if not parameters['delimiter']: parameters['delimiter'] = "\n" elif parameters['delimiter'] == "\\n": parameters['delimiter'] = "\n" if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']: texts = parameters['text'].split(parameters['delimiter']) else: texts = split_and_recombine_text(parameters['text']) full_start_time = time.time() outdir = f"{args.results_folder}/{voice}/" os.makedirs(outdir, exist_ok=True) audio_cache = {} volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None idx = 0 idx_cache = {} for i, file in enumerate(os.listdir(outdir)): filename = os.path.basename(file) extension = os.path.splitext(filename)[1] if extension != ".json" and extension != ".wav": continue match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename) if match and len(match) > 0: key = int(match[0]) idx_cache[key] = True if len(idx_cache) > 0: keys = sorted(list(idx_cache.keys())) idx = keys[-1] + 1 idx = pad(idx, 4) def get_name(line=0, candidate=0, combined=False): name = f"{idx}" if combined: name = f"{name}_combined" elif len(texts) > 1: name = f"{name}_{line}" if parameters['candidates'] > 1: name = f"{name}_{candidate}" return name def get_info( voice, settings = None, latents = True ): info = {} info.update(parameters) info['time'] = time.time()-full_start_time info['datetime'] = datetime.now().isoformat() info['model'] = tts.autoregressive_model_path info['model_hash'] = tts.autoregressive_model_hash info['progress'] = None del info['progress'] if info['delimiter'] == "\n": info['delimiter'] = "\\n" if settings is not None: for k in settings: if k in info: info[k] = settings[k] if 'half_p' in settings and 'cond_free' in settings: info['experimentals'] = [] if settings['half_p']: info['experimentals'].append("Half Precision") if settings['cond_free']: info['experimentals'].append("Conditioning-Free") if latents and "latents" not in info: voice = info['voice'] model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8] dir = f'{get_voice_dir()}/{voice}/' latents_path = f'{dir}/cond_latents_{model_hash}.pth' if voice == "random" or voice == "microphone": if latents and settings is not None and settings['conditioning_latents']: os.makedirs(dir, exist_ok=True) torch.save(conditioning_latents, latents_path) if latents_path and os.path.exists(latents_path): try: with open(latents_path, 'rb') as f: info['latents'] = base64.b64encode(f.read()).decode("ascii") except Exception as e: pass return info INFERENCING = True for line, cut_text in enumerate(texts): if should_phonemize(): cut_text = phonemizer( cut_text ) if parameters['emotion'] == "Custom": if parameters['prompt'] and parameters['prompt'].strip() != "": cut_text = f"[{parameters['prompt']},] {cut_text}" elif parameters['emotion'] != "None" and parameters['emotion']: cut_text = f"[I am really {parameters['emotion'].lower()},] {cut_text}" progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]' print(f"{progress.msg_prefix} Generating line: {cut_text}") start_time = time.time() # do setting editing match = re.findall(r'^(\{.+\}) (.+?)$', cut_text) override = None if match and len(match) > 0: match = match[0] try: override = json.loads(match[0]) cut_text = match[1].strip() except Exception as e: raise Exception("Prompt settings editing requested, but received invalid JSON") settings = get_settings( override=override ) gen, additionals = tts.tts(cut_text, **settings ) parameters['seed'] = additionals[0] run_time = time.time()-start_time print(f"Generating line took {run_time} seconds") if not isinstance(gen, list): gen = [gen] for j, g in enumerate(gen): audio = g.squeeze(0).cpu() name = get_name(line=line, candidate=j) settings['text'] = cut_text settings['time'] = run_time settings['datetime'] = datetime.now().isoformat(), settings['model'] = tts.autoregressive_model_path settings['model_hash'] = tts.autoregressive_model_hash audio_cache[name] = { 'audio': audio, 'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings) } # save here in case some error happens mid-batch torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate) del gen do_gc() INFERENCING = False for k in audio_cache: audio = audio_cache[k]['audio'] audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate) if volume_adjust is not None: audio = volume_adjust(audio) audio_cache[k]['audio'] = audio torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate) output_voices = [] for candidate in range(parameters['candidates']): if len(texts) > 1: audio_clips = [] for line in range(len(texts)): name = get_name(line=line, candidate=candidate) audio = audio_cache[name]['audio'] audio_clips.append(audio) name = get_name(candidate=candidate, combined=True) audio = torch.cat(audio_clips, dim=-1) torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate) audio = audio.squeeze(0).cpu() audio_cache[name] = { 'audio': audio, 'settings': get_info(voice=voice), 'output': True } else: name = get_name(candidate=candidate) audio_cache[name]['output'] = True if args.voice_fixer: if not voicefixer: progress(0, "Loading voicefix...") load_voicefixer() try: fixed_cache = {} for name in progress.tqdm(audio_cache, desc="Running voicefix..."): del audio_cache[name]['audio'] if 'output' not in audio_cache[name] or not audio_cache[name]['output']: continue path = f'{outdir}/{voice}_{name}.wav' fixed = f'{outdir}/{voice}_{name}_fixed.wav' voicefixer.restore( input=path, output=fixed, cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, #mode=mode, ) fixed_cache[f'{name}_fixed'] = { 'settings': audio_cache[name]['settings'], 'output': True } audio_cache[name]['output'] = False for name in fixed_cache: audio_cache[name] = fixed_cache[name] except Exception as e: print(e) print("\nFailed to run Voicefixer") for name in audio_cache: if 'output' not in audio_cache[name] or not audio_cache[name]['output']: if args.prune_nonfinal_outputs: audio_cache[name]['pruned'] = True os.remove(f'{outdir}/{voice}_{name}.wav') continue output_voices.append(f'{outdir}/{voice}_{name}.wav') if not args.embed_output_metadata: with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: f.write(json.dumps(audio_cache[name]['settings'], indent='\t') ) if args.embed_output_metadata: for name in progress.tqdm(audio_cache, desc="Embedding metadata..."): if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']: continue metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav") metadata['lyrics'] = json.dumps(audio_cache[name]['settings']) metadata.save() if sample_voice is not None: sample_voice = (tts.input_sample_rate, sample_voice.numpy()) info = get_info(voice=voice, latents=False) print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n") info['seed'] = usedSeed if 'latents' in info: del info['latents'] os.makedirs('./config/', exist_ok=True) with open(f'./config/generate.json', 'w', encoding="utf-8") as f: f.write(json.dumps(info, indent='\t') ) stats = [ [ parameters['seed'], "{:.3f}".format(info['time']) ] ] return ( sample_voice, output_voices, stats, ) def cancel_generate(): if not INFERENCING: return import tortoise.api tortoise.api.STOP_SIGNAL = True def hash_file(path, algo="md5", buffer_size=0): hash = None if algo == "md5": hash = hashlib.md5() elif algo == "sha1": hash = hashlib.sha1() else: raise Exception(f'Unknown hash algorithm specified: {algo}') if not os.path.exists(path): raise Exception(f'Path not found: {path}') with open(path, 'rb') as f: if buffer_size > 0: while True: data = f.read(buffer_size) if not data: break hash.update(data) else: hash.update(f.read()) return "{0}".format(hash.hexdigest()) def update_baseline_for_latents_chunks( voice ): global current_voice current_voice = voice path = f'{get_voice_dir()}/{voice}/' if not os.path.isdir(path): return 1 dataset_file = f'./training/{voice}/train.txt' if os.path.exists(dataset_file): return 0 # 0 will leverage using the LJspeech dataset for computing latents files = os.listdir(path) total = 0 total_duration = 0 for file in files: if file[-4:] != ".wav": continue metadata = torchaudio.info(f'{path}/{file}') duration = metadata.num_frames / metadata.sample_rate total_duration += duration total = total + 1 # brain too fried to figure out a better way if args.autocalculate_voice_chunk_duration_size == 0: return int(total_duration / total) if total > 0 else 1 return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1 def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, progress=None): global tts global args unload_whisper() unload_voicefixer() if not tts: if tts_loading: raise Exception("TTS is still initializing...") load_tts() if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") if args.autoregressive_model == "auto": tts.load_autoregressive_model(deduce_autoregressive_model(voice)) if voice: load_from_dataset = voice_latents_chunks == 0 if load_from_dataset: dataset_path = f'./training/{voice}/train.txt' if not os.path.exists(dataset_path): load_from_dataset = False else: with open(dataset_path, 'r', encoding="utf-8") as f: lines = f.readlines() print("Leveraging dataset for computing latents") voice_samples = [] max_length = 0 for line in lines: filename = f'./training/{voice}/{line.split("|")[0]}' waveform = load_audio(filename, 22050) max_length = max(max_length, waveform.shape[-1]) voice_samples.append(waveform) for i in range(len(voice_samples)): voice_samples[i] = pad_or_truncate(voice_samples[i], max_length) voice_latents_chunks = len(voice_samples) if voice_latents_chunks == 0: print("Dataset is empty!") load_from_dataset = True if not load_from_dataset: voice_samples, _ = load_voice(voice, load_latents=False) if voice_samples is None: return conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, progress=progress) if len(conditioning_latents) == 4: conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth' torch.save(conditioning_latents, outfile) print(f'Saved voice latents: {outfile}') return conditioning_latents # superfluous, but it cleans up some things class TrainingState(): def __init__(self, config_path, keep_x_past_checkpoints=0, start=True): self.killed = False self.training_dir = os.path.dirname(config_path) with open(config_path, 'r') as file: self.yaml_config = yaml.safe_load(file) self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8")) self.dataset_path = f"{self.training_dir}/train.txt" with open(self.dataset_path, 'r', encoding="utf-8") as f: self.dataset_size = len(f.readlines()) self.batch_size = self.json_config["batch_size"] self.save_rate = self.json_config["save_rate"] self.epoch = 0 self.epochs = self.json_config["epochs"] self.it = 0 self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size ) self.step = 0 self.steps = int(self.its / self.dataset_size) self.checkpoint = 0 self.checkpoints = int((self.its - self.it) / self.save_rate) self.gpus = self.json_config['gpus'] self.buffer = [] self.open_state = False self.training_started = False self.info = {} self.it_rate = "" self.it_rates = 0 self.epoch_rate = "" self.eta = "?" self.eta_hhmmss = "?" self.nan_detected = False self.last_info_check_at = 0 self.statistics = { 'loss': [], 'lr': [], } self.losses = [] self.metrics = { 'step': "", 'rate': "", 'loss': "", } self.loss_milestones = [ 1.0, 0.15, 0.05 ] if keep_x_past_checkpoints > 0: self.cleanup_old(keep=keep_x_past_checkpoints) if start: self.spawn_process(config_path=config_path, gpus=self.gpus) def spawn_process(self, config_path, gpus=1): if args.tts_backend == "vall-e": self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"'] else: self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path] print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) def parse_metrics(self, data): if isinstance(data, str): if line.find('Training Metrics:') >= 0: data = json.loads(line.split("Training Metrics:")[-1]) data['mode'] = "training" elif line.find('Validation Metrics:') >= 0: data = json.loads(line.split("Validation Metrics:")[-1]) data['mode'] = "validation" else: return self.info = data if 'epoch' in self.info: self.epoch = int(self.info['epoch']) if 'it' in self.info: self.it = int(self.info['it']) if 'step' in self.info: self.step = int(self.info['step']) if 'steps' in self.info: self.steps = int(self.info['steps']) if 'iteration_rate' in self.info: it_rate = self.info['iteration_rate'] self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' self.it_rates += it_rate epoch_rate = self.it_rates / self.it * self.steps if epoch_rate > 0: self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch' try: self.eta = (self.its - self.it) * (self.it_rates / self.it) eta = str(timedelta(seconds=int(self.eta))) self.eta_hhmmss = eta except Exception as e: self.eta_hhmmss = "?" pass self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] if self.epochs != self.its: self.metrics['step'].append(f"{self.it}/{self.its}") if self.steps > 1: self.metrics['step'].append(f"{self.step}/{self.steps}") self.metrics['step'] = ", ".join(self.metrics['step']) epoch = self.epoch + (self.step / self.steps) for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']: if k not in self.info: continue self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']: if k not in self.info: continue self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) self.losses.append( self.statistics['loss'][-1] ) return data def get_status(self): message = None self.metrics['rate'] = [] if self.epoch_rate: self.metrics['rate'].append(self.epoch_rate) if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]: self.metrics['rate'].append(self.it_rate) self.metrics['rate'] = ", ".join(self.metrics['rate']) eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?" self.metrics['loss'] = [] if 'lr' in self.info: self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}') if len(self.losses) > 0: self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') if False and len(self.losses) >= 2: deriv = 0 accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it loss_value = self.losses[-1]["value"] for i in range(accum_length): d1_loss = self.losses[accum_length-i-1]["value"] d2_loss = self.losses[accum_length-i-2]["value"] dloss = (d2_loss - d1_loss) d1_step = self.losses[accum_length-i-1]["it"] d2_step = self.losses[accum_length-i-2]["it"] dstep = (d2_step - d1_step) if dstep == 0: continue inst_deriv = dloss / dstep deriv += inst_deriv deriv = deriv / accum_length print("Deriv: ", deriv) if deriv != 0: # dloss < 0: next_milestone = None for milestone in self.loss_milestones: if loss_value > milestone: next_milestone = milestone break print(f"Loss value: {loss_value} | Next milestone: {next_milestone} | Distance: {loss_value - next_milestone}") if next_milestone: # tfw can do simple calculus but not basic algebra in my head est_its = (next_milestone - loss_value) / deriv * 100 print(f"Estimated: {est_its}") if est_its >= 0: self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') else: est_loss = inst_deriv * (self.its - self.it) + loss_value if est_loss >= 0: self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}') self.metrics['loss'] = ", ".join(self.metrics['loss']) message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]" if self.nan_detected: message = f"[!NaN DETECTED! {self.nan_detected}] {message}" return message def load_statistics(self, update=False): if not os.path.isdir(self.training_dir): return if args.tts_backend == "tortoise": logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ]) else: logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ]) if update: logs = [logs[-1]] infos = {} highest_step = self.last_info_check_at if not update: self.statistics['loss'] = [] self.statistics['lr'] = [] self.it_rates = 0 for log in logs: with open(log, 'r', encoding="utf-8") as f: lines = f.readlines() for line in lines: if line.find('Training Metrics:') >= 0: data = json.loads(line.split("Training Metrics:")[-1]) data['mode'] = "training" elif line.find('Validation Metrics:') >= 0: data = json.loads(line.split("Validation Metrics:")[-1]) data['mode'] = "validation" else: continue if args.tts_backend == "tortoise": if "it" not in data: continue it = data['it'] else: if "global_step" not in data: continue it = data['global_step'] if update and it <= self.last_info_check_at: continue self.parse_metrics(data) self.last_info_check_at = highest_step def cleanup_old(self, keep=2): if keep <= 0: return if args.tts_backend == "vall-e": return if not os.path.isdir(f'{self.training_dir}/finetune/'): return models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.training_dir}/finetune/models/') if d[-8:] == "_gpt.pth" ]) states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.training_dir}/finetune/training_state/') if d[-6:] == ".state" ]) remove_models = models[:-keep] remove_states = states[:-keep] for d in remove_models: path = f'{self.training_dir}/finetune/models/{d}_gpt.pth' print("Removing", path) os.remove(path) for d in remove_states: path = f'{self.training_dir}/finetune/training_state/{d}.state' print("Removing", path) os.remove(path) def parse_valle_metrics(self, data): res = {} res['mode'] = "training" res['loss'] = data['model.loss'] res['lr'] = data['model.lr'] res['it'] = data['global_step'] res['step'] = res['it'] % self.dataset_size res['steps'] = self.steps res['epoch'] = int(res['it'] / self.dataset_size) res['iteration_rate'] = data['elapsed_time'] return res def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ): self.buffer.append(f'{line}') data = None percent = 0 message = None should_return = False MESSAGE_START = 'Start training from epoch' MESSAGE_FINSIHED = 'Finished training' MESSAGE_SAVING = 'Saving models and training states.' MESSAGE_METRICS_TRAINING = 'Training Metrics:' MESSAGE_METRICS_VALIDATION = 'Validation Metrics:' if line.find(MESSAGE_FINSIHED) >= 0: self.killed = True # rip out iteration info elif not self.training_started: if line.find(MESSAGE_START) >= 0: self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations match = re.findall(r'epoch: ([\d,]+)', line) if match and len(match) > 0: self.epoch = int(match[0].replace(",", "")) match = re.findall(r'iter: ([\d,]+)', line) if match and len(match) > 0: self.it = int(match[0].replace(",", "")) self.checkpoints = int((self.its - self.it) / self.save_rate) self.load_statistics() should_return = True else: if line.find(MESSAGE_SAVING) >= 0: self.checkpoint += 1 message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..." percent = self.checkpoint / self.checkpoints self.cleanup_old(keep=keep_x_past_checkpoints) elif line.find(MESSAGE_METRICS_TRAINING) >= 0: data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1]) data['mode'] = "training" elif line.find(MESSAGE_METRICS_VALIDATION) >= 0: data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1]) data['mode'] = "validation" if data is not None: if ': nan' in line and not self.nan_detected: self.nan_detected = self.it self.parse_metrics( data ) message = self.get_status() if message: percent = self.it / float(self.its) # self.epoch / float(self.epochs) if progress is not None: progress(percent, message) self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}') should_return = True if verbose and not self.training_started: should_return = True self.buffer = self.buffer[-buffer_size:] result = None if should_return: result = "".join(self.buffer) if not self.training_started else message return ( result, percent, message, ) try: import altair as alt alt.data_transformers.enable('default', max_rows=None) except Exception as e: print(e) pass def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)): global training_state if training_state and training_state.process: return "Training already in progress" # ensure we have the dvae.pth get_model_path('dvae.pth') # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process torch.multiprocessing.freeze_support() unload_tts() unload_whisper() unload_voicefixer() training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints) for line in iter(training_state.process.stdout.readline, ""): if training_state.killed: return result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if result: yield result if progress is not None and message: progress(percent, message) if training_state: training_state.process.stdout.close() return_code = training_state.process.wait() training_state = None def update_training_dataplot(config_path=None): global training_state losses = None lrs = None if not training_state: if config_path: training_state = TrainingState(config_path=config_path, start=False) training_state.load_statistics() message = training_state.get_status() if len(training_state.statistics['loss']) > 0: losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) if len(training_state.statistics['lr']) > 0: lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) del training_state training_state = None else: # training_state.load_statistics() if len(training_state.statistics['loss']) > 0: losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) if len(training_state.statistics['lr']) > 0: lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) return (losses, lrs) def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)): global training_state if not training_state or not training_state.process: return "Training not in progress" for line in iter(training_state.process.stdout.readline, ""): result, percent, message = training_state.parse( line=line, verbose=verbose, progress=progress ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if result: yield result if progress is not None and message: progress(percent, message) def stop_training(): global training_state if training_state is None: return "No training in progress" print("Killing training process...") training_state.killed = True children = [] if args.tts_backend == "tortoise": # wrapped in a try/catch in case for some reason this fails outside of Linux try: children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']] except Exception as e: pass training_state.process.stdout.close() training_state.process.terminate() training_state.process.kill() elif args.tts_backend == "vall-e": print(training_state.process.communicate(input='quit')[0]) return_code = training_state.process.wait() for p in children: os.kill( p['pid'], signal.SIGKILL ) training_state = None print("Killed training process.") return f"Training cancelled: {return_code}" def get_halfp_model_path(): autoregressive_model_path = get_model_path('autoregressive.pth') return autoregressive_model_path.replace(".pth", "_half.pth") def convert_to_halfp(): autoregressive_model_path = get_model_path('autoregressive.pth') print(f'Converting model to half precision: {autoregressive_model_path}') model = torch.load(autoregressive_model_path) for k in model: model[k] = model[k].half() outfile = get_halfp_model_path() torch.save(model, outfile) print(f'Converted model to half precision: {outfile}') # collapses short segments into the previous segment def whisper_sanitize( results ): sanitized = json.loads(json.dumps(results)) sanitized['segments'] = [] for segment in results['segments']: length = segment['end'] - segment['start'] if length >= MIN_TRAINING_DURATION or len(sanitized['segments']) == 0: sanitized['segments'].append(segment) continue last_segment = sanitized['segments'][-1] # segment already asimilitated it, somehow if last_segment['end'] >= segment['end']: continue """ # segment already asimilitated it, somehow if last_segment['text'].endswith(segment['text']): continue """ last_segment['text'] += segment['text'] last_segment['end'] = segment['end'] for i in range(len(sanitized['segments'])): sanitized['segments'][i]['id'] = i return sanitized def whisper_transcribe( file, language=None ): # shouldn't happen, but it's for safety global whisper_model global whisper_vad global whisper_diarize global whisper_align_model if not whisper_model: load_whisper_model(language=language) if args.whisper_backend == "openai/whisper": if not language: language = None return whisper_model.transcribe(file, language=language) if args.whisper_backend == "lightmare/whispercpp": res = whisper_model.transcribe(file) segments = whisper_model.extract_text_and_timestamps( res ) result = { 'text': [], 'segments': [] } for segment in segments: reparsed = { 'start': segment[0] / 100.0, 'end': segment[1] / 100.0, 'text': segment[2], 'id': len(result['segments']) } result['text'].append( segment[2] ) result['segments'].append(reparsed) result['text'] = " ".join(result['text']) return result if args.whisper_backend == "m-bain/whisperx": import whisperx from whisperx.diarize import assign_word_speakers device = "cuda" if get_device_name() == "cuda" else "cpu" if whisper_vad: if args.whisper_batchsize > 1: result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") else: result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) else: result = whisper_model.transcribe(file) align_model, metadata = whisper_align_model result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device) if whisper_diarize: diarize_segments = whisper_diarize(file) diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True)) diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) # assumes each utterance is single speaker (needs fix) result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True) result_aligned["segments"] = result_segments result_aligned["word_segments"] = word_segments for i in range(len(result_aligned['segments'])): del result_aligned['segments'][i]['word-segments'] del result_aligned['segments'][i]['char-segments'] result['segments'] = result_aligned['segments'] result['text'] = [] for segment in result['segments']: segment['id'] = len(result['text']) result['text'].append(segment['text'].strip()) result['text'] = " ".join(result['text']) return result def validate_waveform( waveform, sample_rate, min_only=False ): if not torch.any(waveform < 0): return "Waveform is empty" num_channels, num_frames = waveform.shape duration = num_frames / sample_rate if duration < MIN_TRAINING_DURATION: return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, MIN_TRAINING_DURATION) if not min_only: if duration > MAX_TRAINING_DURATION: return "Duration too long ({:.3f}s < {:.3f}s)".format(MAX_TRAINING_DURATION, duration) return def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ): unload_tts() global whisper_model if whisper_model is None: load_whisper_model(language=language) results = {} files = sorted( get_voices(load_latents=False)[voice] ) indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' os.makedirs(f'{indir}/audio/', exist_ok=True) if os.path.exists(infile): results = json.load(open(infile, 'r', encoding="utf-8")) for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): basename = os.path.basename(file) modified = False if basename in results and skip_existings: print(f"Skipping already parsed file: {basename}") else: try: result = whisper_transcribe(file, language=language) modified = True except Exception as e: print("Failed to transcribe:", file) continue results[basename] = result try: sanitized = whisper_sanitize(results[basename]) if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']): results[basename] = sanitized modified = True print("Segments sanizited: ", basename) except Exception as e: print("Failed to sanitize:", basename, e) pass waveform, sample_rate = torchaudio.load(file) # resample to the input rate, since it'll get resampled for training anyways # this should also "help" increase throughput a bit when filling the dataloaders waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050) if waveform.shape[0] == 2: waveform = waveform[:1] torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) if modified: with open(infile, 'w', encoding="utf-8") as f: f.write(json.dumps(results, indent='\t')) do_gc() return f"Processed dataset to: {indir}" def slice_waveform( waveform, sample_rate, start, end, trim ): start = int(start * sample_rate) end = int(end * sample_rate) if start < 0: start = 0 if end >= waveform.shape[-1]: end = waveform.shape[-1] - 1 sliced = waveform[:, start:end] error = validate_waveform( sliced, sample_rate, min_only=True ) if trim and not error: sliced = torchaudio.functional.vad( sliced, sample_rate ) return sliced, error def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None, progress=gr.Progress() ): indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' messages = [] if not os.path.exists(infile): raise Exception(f"Missing dataset: {infile}") if results is None: results = json.load(open(infile, 'r', encoding="utf-8")) files = 0 segments = 0 for filename in results: path = f'./voices/{voice}/{filename}' if not os.path.exists(path): path = f'./training/{voice}/{filename}' if not os.path.exists(path): message = f"Missing source audio: {filename}" print(message) messages.append(message) continue files += 1 result = results[filename] waveform, sample_rate = torchaudio.load(path) num_channels, num_frames = waveform.shape duration = num_frames / sample_rate for segment in result['segments']: file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence ) if error: message = f"{error}, skipping... {file}" print(message) messages.append(message) continue sliced, _ = resample( sliced, sample_rate, 22050 ) if waveform.shape[0] == 2: waveform = waveform[:1] torchaudio.save(f"{indir}/audio/{file}", sliced, 22050, encoding="PCM_S", bits_per_sample=16) segments +=1 messages.append(f"Sliced segments: {files} => {segments}.") return "\n".join(messages) # takes an LJSpeech-dataset-formatted .txt file and phonemize it def phonemize_txt_file( path ): with open(path, 'r', encoding='utf-8') as f: lines = f.readlines() reparsed = [] with open(path.replace(".txt", ".phn.txt"), 'a', encoding='utf-8') as f: for line in enumerate_progress(lines, desc='Phonemizing...'): split = line.split("|") audio = split[0] text = split[2] phonemes = phonemizer( text ) reparsed.append(f'{audio}|{phonemes}') f.write(f'\n{audio}|{phonemes}') joined = "\n".join(reparsed) with open(path.replace(".txt", ".phn.txt"), 'w', encoding='utf-8') as f: f.write(joined) return joined # takes an LJSpeech-dataset-formatted .txt (and phonemized .phn.txt from the above) and creates a JSON that should slot in as whisper.json def create_dataset_json( path ): with open(path, 'r', encoding='utf-8') as f: lines = f.readlines() phonemes = None phn_path = path.replace(".txt", ".phn.txt") if os.path.exists(phn_path): with open(phn_path, 'r', encoding='utf-8') as f: phonemes = f.readlines() data = {} for line in lines: split = line.split("|") audio = split[0] text = split[1] data[audio] = { 'text': text.strip() } for line in phonemes: split = line.split("|") audio = split[0] text = split[1] data[audio]['phonemes'] = text.strip() with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f: f.write(json.dumps(data, indent="\t")) def phonemizer( text, language="en-us" ): from phonemizer import phonemize if language == "en": language = "en-us" return phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend ) def should_phonemize(): should = args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json" if should: try: from phonemizer import phonemize except Exception as e: return False return should def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ): indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' if not os.path.exists(infile): raise Exception(f"Missing dataset: {infile}") results = json.load(open(infile, 'r', encoding="utf-8")) errored = 0 messages = [] normalize = True phonemize = should_phonemize() lines = { 'training': [], 'validation': [] } segments = {} for filename in results: use_segment = use_segments result = results[filename] lang = result['language'] language = LANGUAGES[lang] if lang in LANGUAGES else lang normalizer = EnglishTextNormalizer() if language and language == "english" else BasicTextNormalizer() # check if unsegmented text exceeds 200 characters if not use_segment: if len(result['text']) > 200: message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}" print(message) messages.append(message) use_segment = True # check if unsegmented audio exceeds 11.6s if not use_segment: path = f'{indir}/audio/{filename}' if not os.path.exists(path): messages.append(f"Missing source audio: {filename}") errored += 1 continue metadata = torchaudio.info(path) duration = metadata.num_frames / metadata.sample_rate if duration >= MAX_TRAINING_DURATION: message = f"Audio too large, using segments: {filename}" print(message) messages.append(message) use_segment = True # implicitly segment if use_segment and not use_segments: exists = True for segment in result['segments']: duration = segment['end'] - segment['start'] if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration: continue path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if os.path.exists(path): continue exists = False break if not exists: tmp = {} tmp[filename] = result print(f"Audio not segmented, segmenting: {filename}") message = slice_dataset( voice, results=tmp ) print(message) messages = messages + message.split("\n") if not use_segment: segments[filename] = { 'text': result['text'], 'lang': lang, 'language': language, 'normalizer': normalizer, 'phonemes': result['phonemes'] if 'phonemes' in result else None } else: for segment in result['segments']: duration = segment['end'] - segment['start'] if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration: continue segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = { 'text': segment['text'], 'lang': lang, 'language': language, 'normalizer': normalizer, 'phonemes': segment['phonemes'] if 'phonemes' in segment else None } jobs = { 'quantize': [[], []], 'phonemize': [[], []], } for file in enumerate_progress(segments, desc="Parsing segments", progress=progress): result = segments[file] path = f'{indir}/audio/{file}' if not os.path.exists(path): message = f"Missing segment, skipping... {file}" print(message) messages.append(message) errored += 1 continue text = result['text'] lang = result['lang'] language = result['language'] normalizer = result['normalizer'] phonemes = result['phonemes'] if phonemize and phonemes is None: phonemes = phonemizer( text, language=lang ) normalized = normalizer(text) if normalize else text if len(text) > 200: message = f"Text length too long (200 < {len(text)}), skipping... {file}" print(message) messages.append(message) errored += 1 continue waveform, sample_rate = torchaudio.load(path) num_channels, num_frames = waveform.shape duration = num_frames / sample_rate error = validate_waveform( waveform, sample_rate ) if error: message = f"{error}, skipping... {file}" print(message) messages.append(message) errored += 1 continue culled = len(text) < text_length if not culled and audio_length > 0: culled = duration < audio_length line = f'audio/{file}|{phonemes if phonemize and phonemes else text}' lines['training' if not culled else 'validation'].append(line) if culled or args.tts_backend != "vall-e": continue os.makedirs(f'{indir}/valle/', exist_ok=True) qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}' if not os.path.exists(qnt_file): jobs['quantize'][0].append(qnt_file) jobs['quantize'][1].append((waveform, sample_rate)) """ quantized = valle_quantize( waveform, sample_rate ).cpu() torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}') print("Quantized:", file) """ phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}' if not os.path.exists(phn_file): jobs['phonemize'][0].append(phn_file) jobs['phonemize'][1].append(normalized) """ phonemized = valle_phonemize( normalized ) open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized)) print("Phonemized:", file, normalized, text) """ for i in enumerate_progress(range(len(jobs['quantize'][0])), desc="Quantizing", progress=progress): qnt_file = jobs['quantize'][0][i] waveform, sample_rate = jobs['quantize'][1][i] quantized = valle_quantize( waveform, sample_rate ).cpu() torch.save(quantized, qnt_file) print("Quantized:", file) for i in enumerate_progress(range(len(jobs['phonemize'][0])), desc="Phonemizing", progress=progress): phn_file = jobs['phonemize'][0][i] normalized = jobs['phonemize'][1][i] phonemized = valle_phonemize( normalized ) open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized)) print("Phonemized:", file) training_joined = "\n".join(lines['training']) validation_joined = "\n".join(lines['validation']) with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f: f.write(training_joined) with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: f.write(validation_joined) messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}") return "\n".join(messages) def calc_iterations( epochs, lines, batch_size ): return int(math.ceil(epochs * math.ceil(lines / batch_size))) def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ): return [int(iterations * d) for d in schedule] def optimize_training_settings( **kwargs ): messages = [] settings = {} settings.update(kwargs) dataset_path = f"./training/{settings['voice']}/train.txt" with open(dataset_path, 'r', encoding="utf-8") as f: lines = len(f.readlines()) if lines == 0: raise Exception("Empty dataset.") if settings['batch_size'] > lines: settings['batch_size'] = lines messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}") """ if lines % settings['batch_size'] != 0: settings['batch_size'] = int(lines / settings['batch_size']) if settings['batch_size'] == 0: settings['batch_size'] = 1 messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {settings['batch_size']}") """ if settings['gradient_accumulation_size'] == 0: settings['gradient_accumulation_size'] = 1 if settings['batch_size'] / settings['gradient_accumulation_size'] < 2: settings['gradient_accumulation_size'] = int(settings['batch_size'] / 2) if settings['gradient_accumulation_size'] == 0: settings['gradient_accumulation_size'] = 1 messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}") elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0: settings['gradient_accumulation_size'] -= settings['batch_size'] % settings['gradient_accumulation_size'] if settings['gradient_accumulation_size'] == 0: settings['gradient_accumulation_size'] = 1 messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}") if settings['batch_size'] % settings['gpus'] != 0: settings['batch_size'] -= settings['batch_size'] % settings['gpus'] if settings['batch_size'] == 0: settings['batch_size'] = 1 messages.append(f"Batch size not neatly divisible by GPU count, adjusting batch size to: {settings['batch_size']}") def get_device_batch_size( vram ): DEVICE_BATCH_SIZE_MAP = [ (70, 128), # based on an A100-80G, I can safely get a ratio of 4096:32 = 128 (32, 64), # based on my two 6800XTs, I can only really safely get a ratio of 128:2 = 64 (16, 8), # based on an A4000, I can do a ratio of 512:64 = 8:1 (8, 4), # interpolated (6, 2), # based on my 2060, it only really lets me have a batch ratio of 2:1 ] for k, v in DEVICE_BATCH_SIZE_MAP: if vram > (k-1): return v return 1 if settings['gpus'] > get_device_count(): settings['gpus'] = get_device_count() messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}") if settings['gpus'] <= 1: settings['gpus'] = 1 else: messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.") # assuming you have equal GPUs vram = get_device_vram() * settings['gpus'] batch_ratio = int(settings['batch_size'] / settings['gradient_accumulation_size']) batch_cap = get_device_batch_size(vram) if batch_ratio > batch_cap: settings['gradient_accumulation_size'] = int(settings['batch_size'] / batch_cap) messages.append(f"Batch ratio ({batch_ratio}) is expected to exceed your VRAM capacity ({'{:.3f}'.format(vram)}GB, suggested {batch_cap} batch size cap), adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}") iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) if settings['epochs'] < settings['save_rate']: settings['save_rate'] = settings['epochs'] messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}") if settings['epochs'] < settings['validation_rate']: settings['validation_rate'] = settings['epochs'] messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {settings['validation_rate']}") if settings['resume_state'] and not os.path.exists(settings['resume_state']): settings['resume_state'] = None messages.append("Resume path specified, but does not exist. Disabling...") if settings['bitsandbytes']: messages.append("! EXPERIMENTAL ! BitsAndBytes requested.") if settings['half_p']: if settings['bitsandbytes']: settings['half_p'] = False messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...") else: messages.append("! EXPERIMENTAL ! Half Precision requested.") if not os.path.exists(get_halfp_model_path()): convert_to_halfp() steps = int(iterations / settings['epochs']) messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({steps}) steps per epoch)") return settings, messages def save_training_settings( **kwargs ): messages = [] settings = {} settings.update(kwargs) outjson = f'./training/{settings["voice"]}/train.json' with open(outjson, 'w', encoding="utf-8") as f: f.write(json.dumps(settings, indent='\t') ) settings['dataset_path'] = f"./training/{settings['voice']}/train.txt" settings['validation_path'] = f"./training/{settings['voice']}/validation.txt" with open(settings['dataset_path'], 'r', encoding="utf-8") as f: lines = len(f.readlines()) settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) if not settings['source_model'] or settings['source_model'] == "auto": settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth" if settings['half_p']: if not os.path.exists(get_halfp_model_path()): convert_to_halfp() messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps") iterations_per_epoch = settings['iterations'] / settings['epochs'] settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch) settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch) iterations_per_epoch = int(iterations_per_epoch) if settings['save_rate'] < 1: settings['save_rate'] = 1 """ if settings['validation_rate'] < 1: settings['validation_rate'] = 1 """ """ if settings['iterations'] % settings['save_rate'] != 0: adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate'] messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}") settings['iterations'] = adjustment """ settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size']) if not os.path.exists(settings['validation_path']): settings['validation_enabled'] = False messages.append("Validation not found, disabling validation...") elif settings['validation_batch_size'] == 0: settings['validation_enabled'] = False messages.append("Validation batch size == 0, disabling validation...") else: with open(settings['validation_path'], 'r', encoding="utf-8") as f: validation_lines = len(f.readlines()) if validation_lines < settings['validation_batch_size']: settings['validation_batch_size'] = validation_lines messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}") settings['tokenizer_json'] = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0] if settings['gpus'] > get_device_count(): settings['gpus'] = get_device_count() # what an utter mistake this was settings['optimizer'] = 'adamw' # if settings['gpus'] == 1 else 'adamw_zero' if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES: settings['learning_rate_scheme'] = "Multistep" settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[settings['learning_rate_scheme']] learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"] if settings['learning_rate_scheme'] == "MultiStepLR": if not settings['learning_rate_schedule']: settings['learning_rate_schedule'] = LEARNING_RATE_SCHEDULE elif isinstance(settings['learning_rate_schedule'],str): settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule']) settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] ) learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}") learning_rate_schema.append(f" lr_gamma: 0.5") elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart": epochs = settings['epochs'] restarts = settings['learning_rate_restarts'] restart_period = int(epochs / restarts) if 'learning_rate_warmup' not in settings: settings['learning_rate_warmup'] = 0 if 'learning_rate_min' not in settings: settings['learning_rate_min'] = 1e-08 if 'learning_rate_period' not in settings: settings['learning_rate_period'] = [ iterations_per_epoch * restart_period for x in range(epochs) ] settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * restart_period for x in range(restarts) ] # [52, 104, 156, 208] if 'learning_rate_restart_weights' not in settings: settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125] settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5 learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}") learning_rate_schema.append(f" warmup: {settings['learning_rate_warmup']}") learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}") learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}") learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}") settings['learning_rate_scheme'] = "\n".join(learning_rate_schema) if settings['resume_state']: settings['source_model'] = f"# pretrain_model_gpt: '{settings['source_model']}'" settings['resume_state'] = f"resume_state: '{settings['resume_state']}'" else: settings['source_model'] = f"pretrain_model_gpt: '{settings['source_model']}'" settings['resume_state'] = f"# resume_state: '{settings['resume_state']}'" def use_template(template, out): with open(template, 'r', encoding="utf-8") as f: yaml = f.read() # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals for k in settings: if settings[k] is None: continue yaml = yaml.replace(f"${{{k}}}", str(settings[k])) with open(out, 'w', encoding="utf-8") as f: f.write(yaml) if args.tts_backend == "tortoise": use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml') elif args.tts_backend == "vall-e": settings['model_name'] = "[ 'ar-quarter', 'nar-quarter' ]" use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml') messages.append(f"Saved training output") return settings, messages def import_voices(files, saveAs=None, progress=None): global args if not isinstance(files, list): files = [files] for file in enumerate_progress(files, desc="Importing voice files", progress=progress): j, latents = read_generate_settings(file, read_latents=True) if j is not None and saveAs is None: saveAs = j['voice'] if saveAs is None or saveAs == "": raise Exception("Specify a voice name") outdir = f'{get_voice_dir()}/{saveAs}/' os.makedirs(outdir, exist_ok=True) if latents: print(f"Importing latents to {latents}") with open(f'{outdir}/cond_latents.pth', 'wb') as f: f.write(latents) latents = f'{outdir}/cond_latents.pth' print(f"Imported latents to {latents}") else: filename = file.name if filename[-4:] != ".wav": raise Exception("Please convert to a WAV first") path = f"{outdir}/{os.path.basename(filename)}" print(f"Importing voice to {path}") waveform, sample_rate = torchaudio.load(filename) if args.voice_fixer: if not voicefixer: load_voicefixer() waveform, sample_rate = resample(waveform, sample_rate, 44100) torchaudio.save(path, waveform, sample_rate) print(f"Running 'voicefixer' on voice sample: {path}") voicefixer.restore( input = path, output = path, cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, #mode=mode, ) else: torchaudio.save(path, waveform, sample_rate) print(f"Imported voice to {path}") def relative_paths( dirs ): return [ './' + os.path.relpath( d ).replace("\\", "/") for d in dirs ] def get_voice_list(dir=get_voice_dir(), append_defaults=False): defaults = [ "random", "microphone" ] os.makedirs(dir, exist_ok=True) res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) if append_defaults: res = res + defaults return res def get_autoregressive_models(dir="./models/finetunes/", prefixed=False): os.makedirs(dir, exist_ok=True) base = [get_model_path('autoregressive.pth')] halfp = get_halfp_model_path() if os.path.exists(halfp): base.append(halfp) additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) found = [] for training in os.listdir(f'./training/'): if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/finetune/') or not os.path.isdir(f'./training/{training}/finetune/models/'): continue models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/finetune/models/') if d[-8:] == "_gpt.pth" ]) found = found + [ f'./training/{training}/finetune/models/{d}_gpt.pth' for d in models ] res = base + additionals + found if prefixed: for i in range(len(res)): path = res[i] hash = hash_file(path) shorthash = hash[:8] res[i] = f'[{shorthash}] {path}' return ["auto"] + relative_paths(res) def get_diffusion_models(dir="./models/finetunes/", prefixed=False): return relative_paths([ get_model_path('diffusion_decoder.pth') ]) def get_tokenizer_jsons( dir="./models/tokenizers/" ): additionals = sorted([ f'{dir}/{d}' for d in os.listdir(dir) if d[-5:] == ".json" ]) if os.path.isdir(dir) else [] return relative_paths([ "./modules/tortoise-tts/tortoise/data/tokenizer.json" ] + additionals) def tokenize_text( text, config=None, stringed=True, skip_specials=False ): from tortoise.utils.tokenizer import VoiceBpeTokenizer if not config: config = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0] if not tts: tokenizer = VoiceBpeTokenizer(config) else: tokenizer = tts.tokenizer encoded = tokenizer.encode(text) decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=skip_specials).split(" ") if stringed: return "\n".join([ str(encoded), str(decoded) ]) return decoded def get_dataset_list(dir="./training/"): return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ]) def get_training_list(dir="./training/"): if args.tts_backend == "tortoise": return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ]) nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ]) return ars + nars def pad(num, zeroes): return str(num).zfill(zeroes+1) def curl(url): try: req = urllib.request.Request(url, headers={'User-Agent': 'Python'}) conn = urllib.request.urlopen(req) data = conn.read() data = data.decode() data = json.loads(data) conn.close() return data except Exception as e: print(e) return None def check_for_updates( dir = None ): if dir is None: check_for_updates("./.git/") check_for_updates("./.git/modules/dlas/") check_for_updates("./.git/modules/tortoise-tts/") return git_dir = dir if not os.path.isfile(f'{git_dir}/FETCH_HEAD'): print(f"Cannot check for updates for {dir}: not from a git repo") return False with open(f'{git_dir}/FETCH_HEAD', 'r', encoding="utf-8") as f: head = f.read() match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head) if match is None or len(match) == 0: print(f"Cannot check for updates for {dir}: cannot parse FETCH_HEAD") return False match = match[0] local = match[0] host = match[1] owner = match[2] repo = match[3] res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances if res is None or len(res) == 0: print(f"Cannot check for updates for {dir}: cannot fetch from remote") return False remote = res[0]["commit"]["id"] if remote != local: print(f"New version found for {dir}: {local[:8]} => {remote[:8]}") return True return False def enumerate_progress(iterable, desc=None, progress=None, verbose=None): if verbose and desc is not None: print(desc) if progress is None: return tqdm(iterable, disable=False) #not verbose) return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True) def notify_progress(message, progress=None, verbose=True): if verbose: print(message) if progress is None: return progress(0, desc=message) def get_args(): global args return args def setup_args(): global args default_arguments = { 'share': False, 'listen': None, 'check-for-updates': False, 'models-from-local-only': False, 'low-vram': False, 'sample-batch-size': None, 'unsqueeze-sample-batches': False, 'embed-output-metadata': True, 'latents-lean-and-mean': True, 'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it 'voice-fixer-use-cuda': True, 'force-cpu-for-conditioning-latents': False, 'defer-tts-load': False, 'device-override': None, 'prune-nonfinal-outputs': True, 'concurrency-count': 2, 'autocalculate-voice-chunk-duration-size': 10, 'output-sample-rate': 44100, 'output-volume': 1, 'results-folder': "./results/", 'hf-token': None, 'tts-backend': TTSES[0], 'autoregressive-model': None, 'diffusion-model': None, 'vocoder-model': VOCODERS[-1], 'tokenizer-json': None, 'phonemizer-backend': 'espeak', 'whisper-backend': 'openai/whisper', 'whisper-model': "base", 'whisper-batchsize': 1, 'training-default-halfp': False, 'training-default-bnb': True, } if os.path.isfile('./config/exec.json'): with open(f'./config/exec.json', 'r', encoding="utf-8") as f: try: overrides = json.load(f) for k in overrides: default_arguments[k] = overrides[k] except Exception as e: print(e) pass parser = argparse.ArgumentParser() parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere") parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on") parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup") parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models") parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage") parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)") parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.") parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.") parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.") parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)") parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model") parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation") parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch") parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass") parser.add_argument("--unsqueeze-sample-batches", default=default_arguments['unsqueeze-sample-batches'], action='store_true', help="Unsqueezes sample batches to process one by one after sampling") parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)") parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)") parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output") parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory") parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token") parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.") parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.") parser.add_argument("--diffusion-model", default=default_arguments['diffusion-model'], help="Specifies which diffusion model to use for sampling.") parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], action='store_true', help="Specifies with vocoder to use") parser.add_argument("--tokenizer-json", default=default_arguments['tokenizer-json'], help="Specifies which tokenizer json to use for tokenizing.") parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.") parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)") parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.") parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX") parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") parser.add_argument("--os", default="unix", help="Specifies which OS, easily") args = parser.parse_args() args.embed_output_metadata = not args.no_embed_output_metadata if not args.device_override: set_device_name(args.device_override) if args.sample_batch_size == 0 and get_device_batch_size() == 1: print("!WARNING! Automatically deduced sample batch size returned 1.") args.listen_host = None args.listen_port = None args.listen_path = None if args.listen: try: match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0] args.listen_host = match[0] if match[0] != "" else "127.0.0.1" args.listen_port = match[1] if match[1] != "" else None args.listen_path = match[2] if match[2] != "" else "/" except Exception as e: pass if args.listen_port is not None: args.listen_port = int(args.listen_port) if args.listen_port == 0: args.listen_port = None return args def get_default_settings( hypenated=True ): settings = { 'listen': None if not args.listen else args.listen, 'share': args.share, 'low-vram':args.low_vram, 'check-for-updates':args.check_for_updates, 'models-from-local-only':args.models_from_local_only, 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, 'defer-tts-load': args.defer_tts_load, 'prune-nonfinal-outputs': args.prune_nonfinal_outputs, 'device-override': args.device_override, 'sample-batch-size': args.sample_batch_size, 'unsqueeze-sample-batches': args.unsqueeze_sample_batches, 'embed-output-metadata': args.embed_output_metadata, 'latents-lean-and-mean': args.latents_lean_and_mean, 'voice-fixer': args.voice_fixer, 'voice-fixer-use-cuda': args.voice_fixer_use_cuda, 'concurrency-count': args.concurrency_count, 'output-sample-rate': args.output_sample_rate, 'autocalculate-voice-chunk-duration-size': args.autocalculate_voice_chunk_duration_size, 'output-volume': args.output_volume, 'results-folder': args.results_folder, 'hf-token': args.hf_token, 'tts-backend': args.tts_backend, 'autoregressive-model': args.autoregressive_model, 'diffusion-model': args.diffusion_model, 'vocoder-model': args.vocoder_model, 'tokenizer-json': args.tokenizer_json, 'phonemizer-backend': args.phonemizer_backend, 'whisper-backend': args.whisper_backend, 'whisper-model': args.whisper_model, 'whisper-batchsize': args.whisper_batchsize, 'training-default-halfp': args.training_default_halfp, 'training-default-bnb': args.training_default_bnb, } res = {} for k in settings: res[k.replace("-", "_") if not hypenated else k] = settings[k] return res def update_args( **kwargs ): global args settings = get_default_settings(hypenated=False) settings.update(kwargs) args.listen = settings['listen'] args.share = settings['share'] args.check_for_updates = settings['check_for_updates'] args.models_from_local_only = settings['models_from_local_only'] args.low_vram = settings['low_vram'] args.force_cpu_for_conditioning_latents = settings['force_cpu_for_conditioning_latents'] args.defer_tts_load = settings['defer_tts_load'] args.prune_nonfinal_outputs = settings['prune_nonfinal_outputs'] args.device_override = settings['device_override'] args.sample_batch_size = settings['sample_batch_size'] args.unsqueeze_sample_batches = settings['unsqueeze_sample_batches'] args.embed_output_metadata = settings['embed_output_metadata'] args.latents_lean_and_mean = settings['latents_lean_and_mean'] args.voice_fixer = settings['voice_fixer'] args.voice_fixer_use_cuda = settings['voice_fixer_use_cuda'] args.concurrency_count = settings['concurrency_count'] args.output_sample_rate = 44000 args.autocalculate_voice_chunk_duration_size = settings['autocalculate_voice_chunk_duration_size'] args.output_volume = settings['output_volume'] args.results_folder = settings['results_folder'] args.hf_token = settings['hf_token'] args.tts_backend = settings['tts_backend'] args.autoregressive_model = settings['autoregressive_model'] args.diffusion_model = settings['diffusion_model'] args.vocoder_model = settings['vocoder_model'] args.tokenizer_json = settings['tokenizer_json'] args.phonemizer_backend = settings['phonemizer_backend'] args.whisper_backend = settings['whisper_backend'] args.whisper_model = settings['whisper_model'] args.whisper_batchsize = settings['whisper_batchsize'] args.training_default_halfp = settings['training_default_halfp'] args.training_default_bnb = settings['training_default_bnb'] save_args_settings() def save_args_settings(): global args settings = get_default_settings() os.makedirs('./config/', exist_ok=True) with open(f'./config/exec.json', 'w', encoding="utf-8") as f: f.write(json.dumps(settings, indent='\t') ) # super kludgy )`; def import_generate_settings(file = None): if not file: file = "./config/generate.json" res = { 'text': None, 'delimiter': None, 'emotion': None, 'prompt': None, 'voice': "random", 'mic_audio': None, 'voice_latents_chunks': None, 'candidates': None, 'seed': None, 'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'temperature': 0.8, 'diffusion_sampler': "DDIM", 'breathing_room': 8 , 'cvvp_weight': 0.0, 'top_p': 0.8, 'diffusion_temperature': 1.0, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'cond_free_k': 2.0, 'experimentals': None, } settings, _ = read_generate_settings(file, read_latents=False) if settings is not None: res.update(settings) return res def reset_generate_settings(): with open(f'./config/generate.json', 'w', encoding="utf-8") as f: f.write(json.dumps({}, indent='\t') ) return import_generate_settings() def read_generate_settings(file, read_latents=True): j = None latents = None if isinstance(file, list) and len(file) == 1: file = file[0] try: if file is not None: if hasattr(file, 'name'): file = file.name if file[-4:] == ".wav": metadata = music_tag.load_file(file) if 'lyrics' in metadata: j = json.loads(str(metadata['lyrics'])) elif file[-5:] == ".json": with open(file, 'r') as f: j = json.load(f) except Exception as e: pass if j is not None: if 'latents' in j: if read_latents: latents = base64.b64decode(j['latents']) del j['latents'] if "time" in j: j["time"] = "{:.3f}".format(j["time"]) return ( j, latents, ) def version_check_tts( min_version ): global tts if not tts: raise Exception("TTS is not initialized") if not hasattr(tts, 'version'): return False if min_version[0] > tts.version[0]: return True if min_version[1] > tts.version[1]: return True if min_version[2] >= tts.version[2]: return True return False def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None ): global args global tts if restart: unload_tts() if autoregressive_model: args.autoregressive_model = autoregressive_model else: autoregressive_model = args.autoregressive_model if autoregressive_model == "auto": autoregressive_model = deduce_autoregressive_model() if diffusion_model: args.diffusion_model = diffusion_model else: diffusion_model = args.diffusion_model if vocoder_model: args.vocoder_model = vocoder_model else: vocoder_model = args.vocoder_model if tokenizer_json: args.tokenizer_json = tokenizer_json else: tokenizer_json = args.tokenizer_json if get_device_name() == "cpu": print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.") tts_loading = True print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {vocoder_model})") tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches) tts_loading = False get_model_path('dvae.pth') print("Loaded TorToiSe, ready for generation.") return tts setup_tortoise = load_tts def unload_tts(): global tts if tts: del tts tts = None print("Unloaded TTS") do_gc() def reload_tts(): unload_tts() load_tts() def get_current_voice(): global current_voice if current_voice: return current_voice settings, _ = read_generate_settings("./config/generate.json", read_latents=False) if settings and "voice" in settings['voice']: return settings["voice"] return None def deduce_autoregressive_model(voice=None): if not voice: voice = get_current_voice() if voice: if os.path.exists(f'./models/finetunes/{voice}.pth'): return f'./models/finetunes/{voice}.pth' dir = f'./training/{voice}/finetune/models/' if os.path.isdir(dir): counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) names = [ f'{dir}/{d}_gpt.pth' for d in counts ] if len(names) > 0: return names[-1] if args.autoregressive_model != "auto": return args.autoregressive_model return get_model_path('autoregressive.pth') def update_autoregressive_model(autoregressive_model_path): match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path) if match: autoregressive_model_path = match[0] if not autoregressive_model_path or not os.path.exists(autoregressive_model_path): print(f"Invalid model: {autoregressive_model_path}") return args.autoregressive_model = autoregressive_model_path save_args_settings() print(f'Stored autoregressive model to settings: {autoregressive_model_path}') global tts if not tts: if tts_loading: raise Exception("TTS is still initializing...") return if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") if autoregressive_model_path == "auto": autoregressive_model_path = deduce_autoregressive_model() if autoregressive_model_path == tts.autoregressive_model_path: return tts.load_autoregressive_model(autoregressive_model_path) do_gc() return autoregressive_model_path def update_diffusion_model(diffusion_model_path): match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path) if match: diffusion_model_path = match[0] if not diffusion_model_path or not os.path.exists(diffusion_model_path): print(f"Invalid model: {diffusion_model_path}") return args.diffusion_model = diffusion_model_path save_args_settings() print(f'Stored diffusion model to settings: {diffusion_model_path}') global tts if not tts: if tts_loading: raise Exception("TTS is still initializing...") return if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") if diffusion_model_path == "auto": diffusion_model_path = deduce_diffusion_model() if diffusion_model_path == tts.diffusion_model_path: return tts.load_diffusion_model(diffusion_model_path) do_gc() return diffusion_model_path def update_vocoder_model(vocoder_model): args.vocoder_model = vocoder_model save_args_settings() print(f'Stored vocoder model to settings: {vocoder_model}') global tts if not tts: if tts_loading: raise Exception("TTS is still initializing...") return if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") print(f"Loading model: {vocoder_model}") tts.load_vocoder_model(vocoder_model) print(f"Loaded model: {tts.vocoder_model}") do_gc() return vocoder_model def update_tokenizer(tokenizer_json): args.tokenizer_json = tokenizer_json save_args_settings() print(f'Stored tokenizer to settings: {tokenizer_json}') global tts if not tts: if tts_loading: raise Exception("TTS is still initializing...") return if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") print(f"Loading tokenizer vocab: {tokenizer_json}") tts.load_tokenizer_json(tokenizer_json) print(f"Loaded tokenizer vocab: {tts.tokenizer_json}") do_gc() return vocoder_model def load_voicefixer(restart=False): global voicefixer if restart: unload_voicefixer() try: print("Loading Voicefixer") from voicefixer import VoiceFixer voicefixer = VoiceFixer() print("Loaded Voicefixer") except Exception as e: print(f"Error occurred while tring to initialize voicefixer: {e}") if voicefixer: del voicefixer voicefixer = None def unload_voicefixer(): global voicefixer if voicefixer: del voicefixer voicefixer = None print("Unloaded Voicefixer") do_gc() def load_whisper_model(language=None, model_name=None, progress=None): global whisper_model global whisper_vad global whisper_diarize global whisper_align_model if args.whisper_backend not in WHISPER_BACKENDS: raise Exception(f"unavailable backend: {args.whisper_backend}") if not model_name: model_name = args.whisper_model else: args.whisper_model = model_name save_args_settings() if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS: model_name = f'{model_name}.{language}' print(f"Loading specialized model for language: {language}") notify_progress(f"Loading Whisper model: {model_name}", progress) if args.whisper_backend == "openai/whisper": import whisper try: #is it possible for model to fit on vram but go oom later on while executing on data? whisper_model = whisper.load_model(model_name) except: print("Out of VRAM memory. falling back to loading Whisper on CPU.") whisper_model = whisper.load_model(model_name, device="cpu") elif args.whisper_backend == "lightmare/whispercpp": from whispercpp import Whisper if not language: language = 'auto' b_lang = language.encode('ascii') whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang) elif args.whisper_backend == "m-bain/whisperx": import whisperx device = "cuda" if get_device_name() == "cuda" else "cpu" whisper_model = whisperx.load_model(model_name, device) if not args.hf_token: print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.") try: from pyannote.audio import Inference, Pipeline whisper_vad = Inference( "pyannote/segmentation", pre_aggregation_hook=lambda segmentation: segmentation, use_auth_token=args.hf_token, device=torch.device(device), ) whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token) except Exception as e: pass whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device) print("Loaded Whisper model") def unload_whisper(): global whisper_model global whisper_vad global whisper_diarize global whisper_align_model if whisper_vad: del whisper_vad whisper_vad = None if whisper_diarize: del whisper_diarize whisper_diarize = None if whisper_align_model: del whisper_align_model whisper_align_model = None if whisper_model: del whisper_model whisper_model = None print("Unloaded Whisper") do_gc()