# ai-voice-cloning/src/utils.py
import os
if 'XDG_CACHE_HOME' not in os.environ:
	os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))

if 'TORTOISE_MODELS_DIR' not in os.environ:
	os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))

if 'TRANSFORMERS_CACHE' not in os.environ:
	os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
import argparse
import time
import math
import json
import base64
import re
import urllib.request
import signal
import gc
import subprocess
import psutil
import yaml
import hashlib
import string
import random
from tqdm import tqdm
import torch
import torchaudio
import music_tag
import gradio as gr
import gradio.utils
import pandas as pd
import numpy as np
from glob import glob
from datetime import datetime
from datetime import timedelta
from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
TTSES = ['tortoise']
INFERENCING = False
GENERATE_SETTINGS_ARGS = None
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
LEARNING_RATE_SCHEDULE = [ 2, 4, 9, 18, 25, 33, 50 ]
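# NOTE: assumed interpretation — these look like the epoch milestones consumed by
# the "Multistep" (MultiStepLR) scheme above, i.e. the LR decays as training
# passes each listed epoch; the trainer config is what actually gives them meaning.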
RESAMPLERS = {}
MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669
MAX_TRAINING_CHAR_LENGTH = 200
VALLE_ENABLED = False
BARK_ENABLED = False
VERBOSE_DEBUG = True
KKS = None
PYKAKASI_ENABLED = False
import traceback
try:
	import pykakasi
	KKS = pykakasi.kakasi()
	PYKAKASI_ENABLED = True
except Exception as e:
	#if VERBOSE_DEBUG:
	#	print(traceback.format_exc())
	pass
try:
	from whisper.normalizers.english import EnglishTextNormalizer
	from whisper.normalizers.basic import BasicTextNormalizer
	from whisper.tokenizer import LANGUAGES
	print("Whisper detected")
except Exception as e:
	if VERBOSE_DEBUG:
		print(traceback.format_exc())
	pass
try:
	from vall_e.emb.qnt import encode as valle_quantize
	from vall_e.emb.g2p import encode as valle_phonemize
	from vall_e.inference import TTS as VALLE_TTS
	import soundfile
	print("VALL-E detected")
	VALLE_ENABLED = True
except Exception as e:
	if VERBOSE_DEBUG:
		print(traceback.format_exc())
	pass

if VALLE_ENABLED:
	TTSES.append('vall-e')

# torchaudio.set_audio_backend('soundfile')
try:
	import bark
	from bark import text_to_semantic
	from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
	from bark.api import generate_audio as bark_generate_audio
	from encodec.utils import convert_audio
	from scipy.io.wavfile import write as write_wav
	print("Bark detected")
	BARK_ENABLED = True
except Exception as e:
	if VERBOSE_DEBUG:
		print(traceback.format_exc())
	pass

if BARK_ENABLED:
	TTSES.append('bark')
	def semantic_to_audio_tokens(
		semantic_tokens,
		history_prompt = None,
		temp = 0.7,
		silent = False,
		output_full = False,
	):
		coarse_tokens = generate_coarse(
			semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, use_kv_caching=True
		)
		fine_tokens = generate_fine(coarse_tokens, history_prompt=history_prompt, temp=0.5)

		if output_full:
			full_generation = {
				"semantic_prompt": semantic_tokens,
				"coarse_prompt": coarse_tokens,
				"fine_prompt": fine_tokens,
			}
			return full_generation
		return fine_tokens
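	# Bark's pipeline runs text -> semantic tokens -> coarse EnCodec tokens ->
	# fine EnCodec tokens -> waveform; the helper above covers the two middle
	# steps. A minimal sketch of the full chain (assuming Bark is installed):
	#
	#   semantic = text_to_semantic("Hello world.", history_prompt=None, temp=0.7)
	#   fine = semantic_to_audio_tokens(semantic)
	#   wav = codec_decode(fine)  # numpy waveform at BARK_SAMPLE_RATE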
	class Bark_TTS():
		def __init__(self, small=False):
			self.input_sample_rate = BARK_SAMPLE_RATE
			self.output_sample_rate = BARK_SAMPLE_RATE # args.output_sample_rate

			preload_models(
				text_use_gpu=True,
				coarse_use_gpu=True,
				fine_use_gpu=True,
				codec_use_gpu=True,
				text_use_small=small,
				coarse_use_small=small,
				fine_use_small=small,
				force_reload=False
			)

			self.device = get_device_name()

			try:
				from vocos import Vocos
				self.vocos_enabled = True
				print("Vocos detected")
			except Exception as e:
				if VERBOSE_DEBUG:
					print(traceback.format_exc())
				self.vocos_enabled = False

			try:
				from hubert.hubert_manager import HuBERTManager
				hubert_manager = HuBERTManager()
				hubert_manager.make_sure_hubert_installed()
				hubert_manager.make_sure_tokenizer_installed()
				self.hubert_enabled = True
				print("HuBERT detected")
			except Exception as e:
				if VERBOSE_DEBUG:
					print(traceback.format_exc())
				self.hubert_enabled = False

			if self.vocos_enabled:
				self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
		def create_voice( self, voice ):
			transcription_json = f'./training/{voice}/whisper.json'
			if not os.path.exists(transcription_json):
				raise Exception(f"Transcription for voice not found: {voice}")

			transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
			candidates = []
			for file in transcriptions:
				result = transcriptions[file]
				added = 0
				for segment in result['segments']:
					path = file.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
					# check if the slice actually exists
					if not os.path.exists(f'./training/{voice}/audio/{path}'):
						continue

					entry = (
						path,
						segment['end'] - segment['start'],
						segment['text']
					)
					candidates.append(entry)
					added = added + 1

				# if nothing got added (assuming because nothing was sliced), use the master file
				if added == 0: # added < len(result['segments']):
					start = 0
					end = 0
					for segment in result['segments']:
						start = min( start, segment['start'] )
						end = max( end, segment['end'] )

					entry = (
						file,
						end - start,
						result['text']
					)
					candidates.append(entry)

			candidates.sort(key=lambda x: x[1])
			candidate = random.choice(candidates)
			audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
			text = candidate[-1]

			print("Using as reference:", audio_filepath, text)

			# Load and pre-process the audio waveform
			model = load_codec_model(use_gpu=True)
			wav, sr = torchaudio.load(audio_filepath)
			wav = convert_audio(wav, sr, model.sample_rate, model.channels)

			# generate semantic tokens
			if self.hubert_enabled:
				from hubert.pre_kmeans_hubert import CustomHubert
				from hubert.customtokenizer import CustomTokenizer

				wav = wav.to(self.device)

				# Extract discrete codes from EnCodec
				with torch.no_grad():
					encoded_frames = model.encode(wav.unsqueeze(0))
				codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]

				# get seconds of audio
				seconds = wav.shape[-1] / model.sample_rate

				# Load the HuBERT model
				hubert_model = CustomHubert(checkpoint_path='./data/models/hubert/hubert.pt').to(self.device)
				# Load the CustomTokenizer model
				tokenizer = CustomTokenizer.load_from_checkpoint('./data/models/hubert/tokenizer.pth').to(self.device)

				semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
				semantic_tokens = tokenizer.get_token(semantic_vectors)

				# move codes to cpu
				codes = codes.cpu().numpy()
				# move semantic tokens to cpu
				semantic_tokens = semantic_tokens.cpu().numpy()
			else:
				wav = wav.unsqueeze(0).to(self.device)

				# Extract discrete codes from EnCodec
				with torch.no_grad():
					encoded_frames = model.encode(wav)
				codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]

				# get seconds of audio
				seconds = wav.shape[-1] / model.sample_rate

				# generate semantic tokens
				semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)

			# print(bark.__file__)
			bark_location = os.path.dirname(os.path.relpath(bark.__file__)) # './modules/bark/bark/'
			output_path = f'./{bark_location}/assets/prompts/' + voice.replace("/", "_") + '.npz'
			np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
		def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
			if voice == "random":
				voice = None
			else:
				if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
					self.create_voice( voice )
				voice = voice.replace("/", "_")
				if voice not in ALLOWED_PROMPTS:
					ALLOWED_PROMPTS.add( voice )

			semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
			audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )

			if self.vocos_enabled:
				audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
				features = self.vocos.codes_to_features(audio_tokens_torch)
				wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
			else:
				wav = codec_decode( audio_tokens )

			return ( wav, BARK_SAMPLE_RATE )
			# return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
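	# Usage sketch (assumes the Bark models are downloaded; "random" skips voice
	# cloning, any other voice needs a whisper.json under ./training/<voice>/ so
	# create_voice() above can build a speaker prompt):
	#
	#   bark_tts = Bark_TTS(small=True)
	#   wav, sr = bark_tts.inference("Hello there.", "random")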
args = None
tts = None
tts_loading = False
webui = None
voicefixer = None
whisper_model = None
whisper_align_model = None
training_state = None
current_voice = None
def cleanup_voice_name( name ):
	return name.split("/")[-1]
def resample( waveform, input_rate, output_rate=44100 ):
	# mono-ize
	waveform = torch.mean(waveform, dim=0, keepdim=True)

	if input_rate == output_rate:
		return waveform, output_rate

	key = f'{input_rate}:{output_rate}'
	if key not in RESAMPLERS:
		RESAMPLERS[key] = torchaudio.transforms.Resample(
			input_rate,
			output_rate,
			lowpass_filter_width=16,
			rolloff=0.85,
			resampling_method="kaiser_window",
			beta=8.555504641634386,
		)

	return RESAMPLERS[key]( waveform ), output_rate
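# Usage sketch: resample() downmixes to mono and reuses one cached Resample
# transform per rate pair, so repeated calls with the same rates avoid
# rebuilding the kernel ("sample.wav" is a hypothetical input):
#
#   wav, sr = torchaudio.load("sample.wav")
#   wav, sr = resample(wav, sr, output_rate=22050)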
def generate(**kwargs):
	if args.tts_backend == "tortoise":
		return generate_tortoise(**kwargs)
	if args.tts_backend == "vall-e":
		return generate_valle(**kwargs)
	if args.tts_backend == "bark":
		return generate_bark(**kwargs)
def generate_bark(**kwargs):
	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	if parameters['seed'] == 0:
		parameters['seed'] = None
	usedSeed = parameters['seed']

	global args
	global tts
	global INFERENCING # declared so cancel_generate() can see the flag flip

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		if progress is not None:
			notify_progress("Initializing TTS...", progress=progress)
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	voice_samples = None
	conditioning_latents = None
	sample_voice = None

	voice_cache = {}

	def get_settings( override=None ):
		settings = {
			'voice': parameters['voice'],
			'text_temp': float(parameters['temperature']),
			'waveform_temp': float(parameters['temperature']),
		}

		# could be better to just do a ternary on everything above, but i am not a professional
		selected_voice = voice
		if override is not None:
			if 'voice' in override:
				selected_voice = override['voice']

			for k in override:
				if k not in settings:
					continue
				settings[k] = override[k]

		return settings

	if not parameters['delimiter']:
		parameters['delimiter'] = "\n"
	elif parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	idx = 0
	idx_cache = {}
	for i, file in enumerate(os.listdir(outdir)):
		filename = os.path.basename(file)
		extension = os.path.splitext(filename)[-1][1:]
		if extension != "json" and extension != "wav":
			continue
		match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
		if match and len(match) > 0:
			key = int(match[0])
			idx_cache[key] = True

	if len(idx_cache) > 0:
		keys = sorted(list(idx_cache.keys()))
		idx = keys[-1] + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		info['progress'] = None
		del info['progress']

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]

		return info

	INFERENCING = True
	for line, cut_text in enumerate(texts):
		tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
		print(f"{tqdm_prefix} Generating line: {cut_text}")
		start_time = time.time()

		# do setting editing
		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
		override = None
		if match and len(match) > 0:
			match = match[0]
			try:
				override = json.loads(match[0])
				cut_text = match[1].strip()
			except Exception as e:
				raise Exception("Prompt settings editing requested, but received invalid JSON")

		settings = get_settings( override=override )
		gen = tts.inference(cut_text, **settings )

		run_time = time.time()-start_time
		print(f"Generating line took {run_time} seconds")

		if not isinstance(gen, list):
			gen = [gen]

		for j, g in enumerate(gen):
			wav, sr = g
			name = get_name(line=line, candidate=j)

			settings['text'] = cut_text
			settings['time'] = run_time
			settings['datetime'] = datetime.now().isoformat()

			# save here in case some error happens mid-batch
			if tts.vocos_enabled:
				torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
			else:
				write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)

			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			audio_cache[name] = {
				'audio': wav,
				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
			}

		del gen
		do_gc()
	INFERENCING = False

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio = audio_cache[name]['audio']
				audio_clips.append(audio)

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			try:
				name = get_name(candidate=candidate)
				audio_cache[name]['output'] = True
			except Exception as e:
				for name in audio_cache:
					audio_cache[name]['output'] = True

	if args.voice_fixer:
		if not voicefixer:
			notify_progress("Loading voicefix...", progress=progress)
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
		if not args.embed_output_metadata:
			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in tqdm(audio_cache, desc="Embedding metadata..."):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	if sample_voice is not None:
		sample_voice = (tts.input_sample_rate, sample_voice.numpy())

	info = get_info(voice=voice, latents=False)
	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)
def generate_valle(**kwargs):
	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	if parameters['seed'] == 0:
		parameters['seed'] = None
	usedSeed = parameters['seed']

	global args
	global tts
	global INFERENCING # declared so cancel_generate() can see the flag flip

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		if progress is not None:
			notify_progress("Initializing TTS...", progress=progress)
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	voice_samples = None
	conditioning_latents = None
	sample_voice = None

	voice_cache = {}

	def fetch_voice( voice ):
		if voice in voice_cache:
			return voice_cache[voice]

		"""
		voice_dir = f'./training/{voice}/audio/'
		if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0:
			voice_dir = f'./voices/{voice}/'
		files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
		"""
		if os.path.isdir(f'./training/{voice}/audio/'):
			files = get_voice(name="audio", dir=f"./training/{voice}/", load_latents=False)
		else:
			files = get_voice(name=voice, load_latents=False)

		# return files
		voice_cache[voice] = random.sample(files, k=min(3, len(files)))
		return voice_cache[voice]

	def get_settings( override=None ):
		settings = {
			'ar_temp': float(parameters['temperature']),
			'nar_temp': float(parameters['temperature']),
			'max_ar_steps': parameters['num_autoregressive_samples'],
		}

		# could be better to just do a ternary on everything above, but i am not a professional
		selected_voice = voice
		if override is not None:
			if 'voice' in override:
				selected_voice = override['voice']

			for k in override:
				if k not in settings:
					continue
				settings[k] = override[k]

		settings['references'] = fetch_voice(voice=selected_voice) # [ fetch_voice(voice=selected_voice) for _ in range(3) ]
		return settings

	if not parameters['delimiter']:
		parameters['delimiter'] = "\n"
	elif parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	idx = 0
	idx_cache = {}
	for i, file in enumerate(os.listdir(outdir)):
		filename = os.path.basename(file)
		extension = os.path.splitext(filename)[-1][1:]
		if extension != "json" and extension != "wav":
			continue
		match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
		if match and len(match) > 0:
			key = int(match[0])
			idx_cache[key] = True

	if len(idx_cache) > 0:
		keys = sorted(list(idx_cache.keys()))
		idx = keys[-1] + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		info['progress'] = None
		del info['progress']

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]

		return info

	INFERENCING = True
	for line, cut_text in enumerate(texts):
		tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
		print(f"{tqdm_prefix} Generating line: {cut_text}")
		start_time = time.time()

		# do setting editing
		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
		override = None
		if match and len(match) > 0:
			match = match[0]
			try:
				override = json.loads(match[0])
				cut_text = match[1].strip()
			except Exception as e:
				raise Exception("Prompt settings editing requested, but received invalid JSON")

		name = get_name(line=line, candidate=0)

		settings = get_settings( override=override )
		references = settings['references']
		settings.pop("references")
		settings['out_path'] = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'

		gen = tts.inference(cut_text, references, **settings )

		run_time = time.time()-start_time
		print(f"Generating line took {run_time} seconds")

		if not isinstance(gen, list):
			gen = [gen]

		for j, g in enumerate(gen):
			wav, sr = g
			name = get_name(line=line, candidate=j)

			settings['text'] = cut_text
			settings['time'] = run_time
			settings['datetime'] = datetime.now().isoformat()

			# save here in case some error happens mid-batch
			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
			#soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

			audio_cache[name] = {
				'audio': wav,
				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
			}

		del gen
		do_gc()
	INFERENCING = False

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio = audio_cache[name]['audio']
				audio_clips.append(audio)

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			name = get_name(candidate=candidate)
			audio_cache[name]['output'] = True

	if args.voice_fixer:
		if not voicefixer:
			notify_progress("Loading voicefix...", progress=progress)
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
		if not args.embed_output_metadata:
			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in tqdm(audio_cache, desc="Embedding metadata..."):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	if sample_voice is not None:
		sample_voice = (tts.input_sample_rate, sample_voice.numpy())

	info = get_info(voice=voice, latents=False)
	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)
def generate_tortoise(**kwargs):
	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	if parameters['seed'] == 0:
		parameters['seed'] = None
	usedSeed = parameters['seed']

	global args
	global tts
	global INFERENCING # declared so cancel_generate() can see the flag flip

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	voice_samples = None
	conditioning_latents = None
	sample_voice = None

	voice_cache = {}

	def fetch_voice( voice ):
		cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
		if cache_key in voice_cache:
			return voice_cache[cache_key]

		print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
		sample_voice = None
		if voice == "microphone":
			if parameters['mic_audio'] is None:
				raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
			voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None
		elif voice == "random":
			voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
		else:
			if progress is not None:
				notify_progress(f"Loading voice: {voice}", progress=progress)
			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)

		if voice_samples and len(voice_samples) > 0:
			if conditioning_latents is None:
				conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks'])

			sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
			voice_samples = None

		voice_cache[cache_key] = (voice_samples, conditioning_latents, sample_voice)
		return voice_cache[cache_key]

	def get_settings( override=None ):
		settings = {
			'temperature': float(parameters['temperature']),
			'top_p': float(parameters['top_p']),
			'diffusion_temperature': float(parameters['diffusion_temperature']),
			'length_penalty': float(parameters['length_penalty']),
			'repetition_penalty': float(parameters['repetition_penalty']),
			'cond_free_k': float(parameters['cond_free_k']),

			'num_autoregressive_samples': parameters['num_autoregressive_samples'],
			'sample_batch_size': args.sample_batch_size,
			'diffusion_iterations': parameters['diffusion_iterations'],

			'voice_samples': None,
			'conditioning_latents': None,

			'use_deterministic_seed': parameters['seed'],
			'return_deterministic_state': True,

			'k': parameters['candidates'],
			'diffusion_sampler': parameters['diffusion_sampler'],
			'breathing_room': parameters['breathing_room'],
			'half_p': "Half Precision" in parameters['experimentals'],
			'cond_free': "Conditioning-Free" in parameters['experimentals'],
			'cvvp_amount': parameters['cvvp_weight'],

			'autoregressive_model': args.autoregressive_model,
			'diffusion_model': args.diffusion_model,
			'tokenizer_json': args.tokenizer_json,
		}

		# could be better to just do a ternary on everything above, but i am not a professional
		selected_voice = voice
		if override is not None:
			if 'voice' in override:
				selected_voice = override['voice']

			for k in override:
				if k not in settings:
					continue
				settings[k] = override[k]

		if settings['autoregressive_model'] is not None:
			if settings['autoregressive_model'] == "auto":
				settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
			tts.load_autoregressive_model(settings['autoregressive_model'])

		if settings['diffusion_model'] is not None:
			if settings['diffusion_model'] == "auto":
				settings['diffusion_model'] = deduce_diffusion_model(selected_voice)
			tts.load_diffusion_model(settings['diffusion_model'])

		if settings['tokenizer_json'] is not None:
			tts.load_tokenizer_json(settings['tokenizer_json'])

		settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(voice=selected_voice)

		# clamp it down for the insane users who want this
		# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
		settings['sample_batch_size'] = args.sample_batch_size
		if not settings['sample_batch_size']:
			settings['sample_batch_size'] = tts.autoregressive_batch_size
		if settings['num_autoregressive_samples'] < settings['sample_batch_size']:
			settings['sample_batch_size'] = settings['num_autoregressive_samples']

		if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
			print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.")
			settings['cvvp_amount'] = 0

		return settings

	if not parameters['delimiter']:
		parameters['delimiter'] = "\n"
	elif parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	idx = 0
	idx_cache = {}
	for i, file in enumerate(os.listdir(outdir)):
		filename = os.path.basename(file)
		extension = os.path.splitext(filename)[-1][1:]
		if extension != "json" and extension != "wav":
			continue
		match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
		if match and len(match) > 0:
			key = int(match[0])
			idx_cache[key] = True

	if len(idx_cache) > 0:
		keys = sorted(list(idx_cache.keys()))
		idx = keys[-1] + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		info['model'] = tts.autoregressive_model_path
		info['model_hash'] = tts.autoregressive_model_hash

		info['progress'] = None
		del info['progress']

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]

			if 'half_p' in settings and 'cond_free' in settings:
				info['experimentals'] = []
				if settings['half_p']:
					info['experimentals'].append("Half Precision")
				if settings['cond_free']:
					info['experimentals'].append("Conditioning-Free")

		if latents and "latents" not in info:
			voice = info['voice']
			model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8]

			dir = f'{get_voice_dir()}/{voice}/'
			latents_path = f'{dir}/cond_latents_{model_hash}.pth'

			if voice == "random" or voice == "microphone":
				if latents and settings is not None and settings['conditioning_latents']:
					os.makedirs(dir, exist_ok=True)
					torch.save(conditioning_latents, latents_path)

			if latents_path and os.path.exists(latents_path):
				try:
					with open(latents_path, 'rb') as f:
						info['latents'] = base64.b64encode(f.read()).decode("ascii")
				except Exception as e:
					pass

		return info

	INFERENCING = True
	for line, cut_text in enumerate(texts):
		if should_phonemize():
			cut_text = phonemizer( cut_text )

		if parameters['emotion'] == "Custom":
			if parameters['prompt'] and parameters['prompt'].strip() != "":
				cut_text = f"[{parameters['prompt']},] {cut_text}"
		elif parameters['emotion'] != "None" and parameters['emotion']:
			cut_text = f"[I am really {parameters['emotion'].lower()},] {cut_text}"

		tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
		print(f"{tqdm_prefix} Generating line: {cut_text}")
		start_time = time.time()

		# do setting editing
		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
		override = None
		if match and len(match) > 0:
			match = match[0]
			try:
				override = json.loads(match[0])
				cut_text = match[1].strip()
			except Exception as e:
				raise Exception("Prompt settings editing requested, but received invalid JSON")

		settings = get_settings( override=override )
		gen, additionals = tts.tts(cut_text, **settings )

		parameters['seed'] = additionals[0]
		run_time = time.time()-start_time
		print(f"Generating line took {run_time} seconds")

		if not isinstance(gen, list):
			gen = [gen]

		for j, g in enumerate(gen):
			audio = g.squeeze(0).cpu()
			name = get_name(line=line, candidate=j)

			settings['text'] = cut_text
			settings['time'] = run_time
			settings['datetime'] = datetime.now().isoformat()
			if args.tts_backend == "tortoise":
				settings['model'] = tts.autoregressive_model_path
				settings['model_hash'] = tts.autoregressive_model_hash

			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
			}

			# save here in case some error happens mid-batch
			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)

		del gen
		do_gc()
	INFERENCING = False

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio = audio_cache[name]['audio']
				audio_clips.append(audio)

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			name = get_name(candidate=candidate)
			audio_cache[name]['output'] = True

	if args.voice_fixer:
		if not voicefixer:
			notify_progress("Loading voicefix...", progress=progress)
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
		if not args.embed_output_metadata:
			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in tqdm(audio_cache, desc="Embedding metadata..."):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	if sample_voice is not None:
		sample_voice = (tts.input_sample_rate, sample_voice.numpy())

	info = get_info(voice=voice, latents=False)
	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)
def cancel_generate():
	if not INFERENCING:
		return

	import tortoise.api
	tortoise.api.STOP_SIGNAL = True
def hash_file(path, algo="md5", buffer_size=0):
	hash = None
	if algo == "md5":
		hash = hashlib.md5()
	elif algo == "sha1":
		hash = hashlib.sha1()
	else:
		raise Exception(f'Unknown hash algorithm specified: {algo}')

	if not os.path.exists(path):
		raise Exception(f'Path not found: {path}')

	with open(path, 'rb') as f:
		if buffer_size > 0:
			while True:
				data = f.read(buffer_size)
				if not data:
					break
				hash.update(data)
		else:
			hash.update(f.read())

	return "{0}".format(hash.hexdigest())
def update_baseline_for_latents_chunks( voice ):
	global current_voice
	current_voice = voice

	path = f'{get_voice_dir()}/{voice}/'
	if not os.path.isdir(path):
		return 1

	dataset_file = f'./training/{voice}/train.txt'
	if os.path.exists(dataset_file):
		return 0 # 0 will leverage using the LJspeech dataset for computing latents

	files = os.listdir(path)

	total = 0
	total_duration = 0

	for file in files:
		if file[-4:] != ".wav":
			continue

		metadata = torchaudio.info(f'{path}/{file}')
		duration = metadata.num_frames / metadata.sample_rate
		total_duration += duration
		total = total + 1

	# brain too fried to figure out a better way
	if args.autocalculate_voice_chunk_duration_size == 0:
		return int(total_duration / total) if total > 0 else 1
	return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, original_ar=False, original_diffusion=False):
	global tts
	global args

	unload_whisper()
	unload_voicefixer()

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	if args.tts_backend == "bark":
		tts.create_voice( voice )
		return

	if args.autoregressive_model == "auto":
		tts.load_autoregressive_model(deduce_autoregressive_model(voice))

	if voice:
		load_from_dataset = voice_latents_chunks == 0

		if load_from_dataset:
			dataset_path = f'./training/{voice}/train.txt'
			if not os.path.exists(dataset_path):
				load_from_dataset = False
			else:
				with open(dataset_path, 'r', encoding="utf-8") as f:
					lines = f.readlines()

				print("Leveraging dataset for computing latents")

				voice_samples = []
				max_length = 0
				for line in lines:
					filename = f'./training/{voice}/{line.split("|")[0]}'

					waveform = load_audio(filename, 22050)
					max_length = max(max_length, waveform.shape[-1])
					voice_samples.append(waveform)

				for i in range(len(voice_samples)):
					voice_samples[i] = pad_or_truncate(voice_samples[i], max_length)

				voice_latents_chunks = len(voice_samples)
				if voice_latents_chunks == 0:
					print("Dataset is empty!")
					load_from_dataset = False # fall back to loading from the voice folder

		if not load_from_dataset:
			voice_samples, _ = load_voice(voice, load_latents=False)

	if voice_samples is None:
		return

	conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, original_ar=original_ar, original_diffusion=original_diffusion)

	if len(conditioning_latents) == 4:
		conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

	outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
	torch.save(conditioning_latents, outfile)
	print(f'Saved voice latents: {outfile}')

	return conditioning_latents
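# Usage sketch (assumes the tortoise backend and a populated voice folder under
# the voices directory; "my_voice" is hypothetical):
#
#   latents = compute_latents(voice="my_voice", voice_latents_chunks=4)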
# superfluous, but it cleans up some things
class TrainingState():
	def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
		self.killed = False

		self.training_dir = os.path.dirname(config_path)
		with open(config_path, 'r') as file:
			self.yaml_config = yaml.safe_load(file)

		self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
		self.dataset_path = f"{self.training_dir}/train.txt"
		with open(self.dataset_path, 'r', encoding="utf-8") as f:
			self.dataset_size = len(f.readlines())

		self.batch_size = self.json_config["batch_size"]
		self.save_rate = self.json_config["save_rate"]

		self.epoch = 0
		self.epochs = self.json_config["epochs"]
		self.it = 0
		self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size )
		self.step = 0
		self.steps = int(self.its / self.dataset_size)
		self.checkpoint = 0
		self.checkpoints = int((self.its - self.it) / self.save_rate)

		self.gpus = self.json_config['gpus']

		self.buffer = []

		self.open_state = False
		self.training_started = False

		self.info = {}

		self.it_rate = ""
		self.it_rates = 0

		self.epoch_rate = ""

		self.eta = "?"
		self.eta_hhmmss = "?"

		self.nan_detected = False

		self.last_info_check_at = 0
		self.statistics = {
			'loss': [],
			'lr': [],
			'grad_norm': [],
		}
		self.losses = []
		self.metrics = {
			'step': "",
			'rate': "",
			'loss': "",
		}

		self.loss_milestones = [ 1.0, 0.15, 0.05 ]

		if args.tts_backend=="vall-e":
			self.valle_last_it = 0
			self.valle_steps = 0

		if keep_x_past_checkpoints > 0:
			self.cleanup_old(keep=keep_x_past_checkpoints)
		if start:
			self.spawn_process(config_path=config_path, gpus=self.gpus)
	def spawn_process(self, config_path, gpus=1):
		if args.tts_backend == "vall-e":
			self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
		else:
			self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]

		print("Spawning process: ", " ".join(self.cmd))
		self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
	def parse_metrics(self, data):
		if isinstance(data, str):
			if data.find('Training Metrics:') >= 0:
				data = json.loads(data.split("Training Metrics:")[-1])
				data['mode'] = "training"
			elif data.find('Validation Metrics:') >= 0:
				data = json.loads(data.split("Validation Metrics:")[-1])
				data['mode'] = "validation"
			else:
				return

		self.info = data
		if 'epoch' in self.info:
			self.epoch = int(self.info['epoch'])
		if 'it' in self.info:
			self.it = int(self.info['it'])
		if 'step' in self.info:
			self.step = int(self.info['step'])
		if 'steps' in self.info:
			self.steps = int(self.info['steps'])

		if 'elapsed_time' in self.info:
			self.info['iteration_rate'] = self.info['elapsed_time']
			del self.info['elapsed_time']

		if 'iteration_rate' in self.info:
			it_rate = self.info['iteration_rate']
			self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
			self.it_rates += it_rate

			if self.it_rates > 0 and self.it * self.steps > 0:
				epoch_rate = self.it_rates / self.it * self.steps
				self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'

			try:
				self.eta = (self.its - self.it) * (self.it_rates / self.it)
				eta = str(timedelta(seconds=int(self.eta)))
				self.eta_hhmmss = eta
			except Exception as e:
				self.eta_hhmmss = "?"
				pass

		self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
		if self.epochs != self.its:
			self.metrics['step'].append(f"{self.it}/{self.its}")
		if self.steps > 1:
			self.metrics['step'].append(f"{self.step}/{self.steps}")
		self.metrics['step'] = ", ".join(self.metrics['step'])

		if args.tts_backend == "tortoise":
			epoch = self.epoch + (self.step / self.steps)
		else:
			epoch = self.info['epoch'] if 'epoch' in self.info else self.it

		if self.it > 0:
			# probably can double for-loop but whatever
			keys = {
				'lrs': ['lr'],
				'losses': ['loss_text_ce', 'loss_mel_ce'],
				'accuracies': [],
				'precisions': [],
				'grad_norms': [],
			}
			if args.tts_backend == "vall-e":
				keys['lrs'] = [
					'ar.lr', 'nar.lr',
				]
				keys['losses'] = [
					# 'ar.loss', 'nar.loss', 'ar+nar.loss',
					'ar.loss.nll', 'nar.loss.nll',
				]
				keys['accuracies'] = [
					'ar.loss.acc', 'nar.loss.acc',
					'ar.stats.acc', 'nar.loss.acc',
				]
				keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ]
				keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm']

			for k in keys['lrs']:
				if k not in self.info:
					continue
				self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

			for k in keys['accuracies']:
				if k not in self.info:
					continue
				self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

			for k in keys['precisions']:
				if k not in self.info:
					continue
				self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

			for k in keys['losses']:
				if k not in self.info:
					continue

				prefix = ""
				if "mode" in self.info and self.info["mode"] == "validation":
					prefix = f'{self.info["name"] if "name" in self.info else "val"}_'

				self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
				self.losses.append( self.statistics['loss'][-1] )

			for k in keys['grad_norms']:
				if k not in self.info:
					continue
				self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

		return data
	def get_status(self):
		message = None

		self.metrics['rate'] = []
		if self.epoch_rate:
			self.metrics['rate'].append(self.epoch_rate)
		if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]:
			self.metrics['rate'].append(self.it_rate)
		self.metrics['rate'] = ", ".join(self.metrics['rate'])

		eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"

		self.metrics['loss'] = []
		if 'lr' in self.info:
			self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
		if len(self.losses) > 0:
			self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')

		if False and len(self.losses) >= 2:
			deriv = 0
			accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
			loss_value = self.losses[-1]["value"]

			for i in range(accum_length):
				d1_loss = self.losses[accum_length-i-1]["value"]
				d2_loss = self.losses[accum_length-i-2]["value"]
				dloss = (d2_loss - d1_loss)

				d1_step = self.losses[accum_length-i-1]["it"]
				d2_step = self.losses[accum_length-i-2]["it"]
				dstep = (d2_step - d1_step)

				if dstep == 0:
					continue

				inst_deriv = dloss / dstep
				deriv += inst_deriv

			deriv = deriv / accum_length

			print("Deriv: ", deriv)

			if deriv != 0: # dloss < 0:
				next_milestone = None
				for milestone in self.loss_milestones:
					if loss_value > milestone:
						next_milestone = milestone
						break

				print(f"Loss value: {loss_value} | Next milestone: {next_milestone} | Distance: {loss_value - next_milestone}")

				if next_milestone:
					# tfw can do simple calculus but not basic algebra in my head
					est_its = (next_milestone - loss_value) / deriv * 100
					print(f"Estimated: {est_its}")
					if est_its >= 0:
						self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
				else:
					est_loss = inst_deriv * (self.its - self.it) + loss_value
					if est_loss >= 0:
						self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')

		self.metrics['loss'] = ", ".join(self.metrics['loss'])

		message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
		if self.nan_detected:
			message = f"[!NaN DETECTED! {self.nan_detected}] {message}"

		return message
	def load_statistics(self, update=False):
		if not os.path.isdir(self.training_dir):
			return

		if args.tts_backend == "tortoise":
			logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
		else:
			log_dir = "logs"
			logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/{log_dir}/') ])

		if update:
			logs = [logs[-1]]

		infos = {}
		highest_step = self.last_info_check_at

		if not update:
			self.statistics['loss'] = []
			self.statistics['lr'] = []
			self.statistics['grad_norm'] = []
			self.it_rates = 0

		unq = {}
		averager = None
		prev_state = 0

		for log in logs:
			with open(log, 'r', encoding="utf-8") as f:
				lines = f.readlines()

			for line in lines:
				line = line.strip()
				if not line:
					continue
				if line[-1] == ".":
					line = line[:-1]

				if line.find('Training Metrics:') >= 0:
					split = line.split("Training Metrics:")[-1]
					data = json.loads(split)
					name = "train"
					mode = "training"
					prev_state = 0
				elif line.find('Validation Metrics:') >= 0:
					data = json.loads(line.split("Validation Metrics:")[-1])
					if "it" not in data:
						data['it'] = it
					if "epoch" not in data:
						data['epoch'] = epoch

					# name = data['name'] if 'name' in data else "val"
					mode = "validation"
					if prev_state == 0:
						name = "subtrain"
					else:
						name = "val"
					prev_state += 1
				else:
					continue

				if "it" not in data:
					continue

				it = data['it']
				epoch = data['epoch']

				if args.tts_backend == "vall-e":
					if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
						averager = {
							'key': f'{it}_{name}',
							'name': name,
							'mode': mode,
							"metrics": {}
						}
						for k in data:
							if data[k] is None:
								continue
							averager['metrics'][k] = [ data[k] ]
					else:
						for k in data:
							if data[k] is None:
								continue
							if k not in averager['metrics']:
								averager['metrics'][k] = [ data[k] ]
							else:
								averager['metrics'][k].append( data[k] )

					unq[f'{it}_{mode}_{name}'] = averager
				else:
					unq[f'{it}_{mode}_{name}'] = data

				if update and it <= self.last_info_check_at:
					continue

		blacklist = [ "batch", "eval" ]
		for it in unq:
			if args.tts_backend == "vall-e":
				stats = unq[it]
				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
				#data = {k: min(v) for k, v in stats['metrics'].items() if k not in blacklist }
				#data = {k: max(v) for k, v in stats['metrics'].items() if k not in blacklist }
				data['name'] = stats['name']
				data['mode'] = stats['mode']
				data['steps'] = len(stats['metrics']['it'])
			else:
				data = unq[it]

			self.parse_metrics(data)

		self.last_info_check_at = highest_step
	def cleanup_old(self, keep=2):
		if keep <= 0:
			return

		if args.tts_backend == "vall-e":
			return

		if not os.path.isdir(f'{self.training_dir}/finetune/'):
			return

		models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.training_dir}/finetune/models/') if d[-8:] == "_gpt.pth" ])
		states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.training_dir}/finetune/training_state/') if d[-6:] == ".state" ])
		remove_models = models[:-keep]
		remove_states = states[:-keep]

		for d in remove_models:
			path = f'{self.training_dir}/finetune/models/{d}_gpt.pth'
			print("Removing", path)
			os.remove(path)
		for d in remove_states:
			path = f'{self.training_dir}/finetune/training_state/{d}.state'
			print("Removing", path)
			os.remove(path)
	def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
		self.buffer.append(f'{line}')

		data = None
		percent = 0
		message = None
		should_return = False

		MESSAGE_START = 'Start training from epoch'
		MESSAGE_FINISHED = 'Finished training'
		MESSAGE_SAVING = 'Saving models and training states.'

		MESSAGE_METRICS_TRAINING = 'Training Metrics:'
		MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'

		if line.find(MESSAGE_FINISHED) >= 0:
			self.killed = True
		# rip out iteration info
		elif not self.training_started:
			if line.find(MESSAGE_START) >= 0:
				self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations

				match = re.findall(r'epoch: ([\d,]+)', line)
				if match and len(match) > 0:
					self.epoch = int(match[0].replace(",", ""))
				match = re.findall(r'iter: ([\d,]+)', line)
				if match and len(match) > 0:
					self.it = int(match[0].replace(",", ""))

				self.checkpoints = int((self.its - self.it) / self.save_rate)

				self.load_statistics()

				should_return = True
		else:
			if line.find(MESSAGE_SAVING) >= 0:
				self.checkpoint += 1
				message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
				percent = self.checkpoint / self.checkpoints

				self.cleanup_old(keep=keep_x_past_checkpoints)
			elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
				data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1])
				data['mode'] = "training"
			elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
				data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1])
				data['mode'] = "validation"

		if data is not None:
			if ': nan' in line and not self.nan_detected:
				self.nan_detected = self.it

			self.parse_metrics( data )
			message = self.get_status()

			if message:
				percent = self.it / float(self.its) # self.epoch / float(self.epochs)
				if progress is not None:
					progress(percent, message)

				self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
				should_return = True

		if verbose and not self.training_started:
			should_return = True

		self.buffer = self.buffer[-buffer_size:]

		result = None
		if should_return:
			result = "".join(self.buffer) if not self.training_started else message

		return (
			result,
			percent,
			message,
		)
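# Minimal sketch of driving TrainingState by hand; run_training() below does
# exactly this, and the config path here is hypothetical:
#
#   state = TrainingState("./training/my_voice/train.yaml")
#   for line in iter(state.process.stdout.readline, ""):
#       result, percent, message = state.parse(line=line)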
try:
	import altair as alt
	alt.data_transformers.enable('default', max_rows=None)
except Exception as e:
	print(e)
	pass
def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
	global training_state
	if training_state and training_state.process:
		return "Training already in progress"

	# ensure we have the dvae.pth
	if args.tts_backend == "tortoise":
		get_model_path('dvae.pth')

	# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
	torch.multiprocessing.freeze_support()

	unload_tts()
	unload_whisper()
	unload_voicefixer()

	training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)

	for line in iter(training_state.process.stdout.readline, ""):
		if training_state is None or training_state.killed:
			return

		result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
		print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
		if result:
			yield result

		if progress is not None and message:
			progress(percent, message)

	if training_state:
		training_state.process.stdout.close()
		return_code = training_state.process.wait()
		training_state = None
def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
global training_state
losses = None
lrs = None
grad_norms = None
x_lim = [ x_min, x_max ]
y_lim = [ y_min, y_max ]
if not training_state:
if config_path:
training_state = TrainingState(config_path=config_path, start=False)
training_state.load_statistics()
message = training_state.get_status()
if training_state:
if not x_lim[-1]:
x_lim[-1] = training_state.epochs
if not y_lim[-1]:
y_lim = None
if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['loss']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value", # x="it",
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['lr']),
x_lim=x_lim,
x="epoch", y="value", # x="it",
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if len(training_state.statistics['grad_norm']) > 0:
grad_norms = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['grad_norm']),
x_lim=x_lim,
x="epoch", y="value", # x="it",
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
	# only discard the state if it was created just for this plot; never drop a live training session
	if created_temporary_state:
		del training_state
		training_state = None
return (losses, lrs, grad_norms)
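# re-attaches the web UI to an already-running training process and resumes streaming its output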
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
global training_state
if not training_state or not training_state.process:
return "Training not in progress"
for line in iter(training_state.process.stdout.readline, ""):
result, percent, message = training_state.parse( line=line, verbose=verbose, progress=progress )
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
if result:
yield result
if progress is not None and message:
progress(percent, message)
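# kills the training process: tortoise is terminated outright (along with any spawned ./src/train.py children), while vall-e is asked to quit over stdin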
def stop_training():
global training_state
if training_state is None:
return "No training in progress"
print("Killing training process...")
training_state.killed = True
children = []
if args.tts_backend == "tortoise":
		# wrapped in a try/except in case enumerating processes fails outside of Linux for some reason
try:
children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
except Exception as e:
pass
training_state.process.stdout.close()
training_state.process.terminate()
training_state.process.kill()
elif args.tts_backend == "vall-e":
print(training_state.process.communicate(input='quit')[0])
return_code = training_state.process.wait()
for p in children:
os.kill( p['pid'], signal.SIGKILL )
training_state = None
print("Killed training process.")
return f"Training cancelled: {return_code}"
def get_halfp_model_path():
autoregressive_model_path = get_model_path('autoregressive.pth')
return autoregressive_model_path.replace(".pth", "_half.pth")
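# casts every tensor in the autoregressive model's state dict to float16 and saves it alongside the original as *_half.pth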
def convert_to_halfp():
autoregressive_model_path = get_model_path('autoregressive.pth')
print(f'Converting model to half precision: {autoregressive_model_path}')
model = torch.load(autoregressive_model_path)
for k in model:
model[k] = model[k].half()
outfile = get_halfp_model_path()
torch.save(model, outfile)
print(f'Converted model to half precision: {outfile}')
# collapses short segments into the previous segment
def whisper_sanitize( results ):
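	# deep copy via a JSON round-trip so the caller's results aren't mutated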
sanitized = json.loads(json.dumps(results))
sanitized['segments'] = []
for segment in results['segments']:
length = segment['end'] - segment['start']
if length >= MIN_TRAINING_DURATION or len(sanitized['segments']) == 0:
sanitized['segments'].append(segment)
continue
last_segment = sanitized['segments'][-1]
		# the previous segment somehow already assimilated this one
if last_segment['end'] >= segment['end']:
continue
"""
# segment already asimilitated it, somehow
if last_segment['text'].endswith(segment['text']):
continue
"""
last_segment['text'] += segment['text']
last_segment['end'] = segment['end']
for i in range(len(sanitized['segments'])):
sanitized['segments'][i]['id'] = i
return sanitized
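# transcribes a file with whichever whisper backend is configured; each backend's output is normalized to openai/whisper's result schema ('text' plus 'segments')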
def whisper_transcribe( file, language=None ):
	# the model should already be loaded by the caller, but load it here as a safety net
global whisper_model
global whisper_align_model
if not whisper_model:
load_whisper_model(language=language)
if args.whisper_backend == "openai/whisper":
if not language:
language = None
return whisper_model.transcribe(file, language=language)
if args.whisper_backend == "lightmare/whispercpp":
res = whisper_model.transcribe(file)
segments = whisper_model.extract_text_and_timestamps( res )
result = {
'text': [],
'segments': []
}
for segment in segments:
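			# whispercpp reports timestamps in centiseconds; convert to seconds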
reparsed = {
'start': segment[0] / 100.0,
'end': segment[1] / 100.0,
'text': segment[2],
'id': len(result['segments'])
}
result['text'].append( segment[2] )
result['segments'].append(reparsed)
result['text'] = " ".join(result['text'])
return result
if args.whisper_backend == "m-bain/whisperx":
import whisperx
device = "cuda" if get_device_name() == "cuda" else "cpu"
result = whisper_model.transcribe(file, batch_size=args.whisper_batchsize)
align_model, metadata = whisper_align_model
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device, return_char_alignments=False)
result['segments'] = result_aligned['segments']
result['text'] = []
for segment in result['segments']:
segment['id'] = len(result['text'])
result['text'].append(segment['text'].strip())
result['text'] = " ".join(result['text'])
return result
def validate_waveform( waveform, sample_rate, min_only=False ):
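	# heuristic: real audio oscillates around zero, so a waveform with no negative samples is treated as empty/silent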
if not torch.any(waveform < 0):
return "Waveform is empty"
num_channels, num_frames = waveform.shape
duration = num_frames / sample_rate
if duration < MIN_TRAINING_DURATION:
return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, MIN_TRAINING_DURATION)
if not min_only:
if duration > MAX_TRAINING_DURATION:
return "Duration too long ({:.3f}s < {:.3f}s)".format(MAX_TRAINING_DURATION, duration)
return
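# transcribes every file for a voice with whisper, writes the results to ./training/{voice}/whisper.json, saves resampled copies of the audio, then sanitizes the segments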
def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
unload_tts()
global whisper_model
if whisper_model is None:
load_whisper_model(language=language)
results = {}
files = get_voice(voice, load_latents=False)
indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json'
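	# vall-e quantizes audio in memory later in the pipeline, so resampled copies don't need to be written to disk for it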
quantize_in_memory = args.tts_backend == "vall-e"
os.makedirs(f'{indir}/audio/', exist_ok=True)
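	# tortoise expects 22050Hz input audio; other backends default to 24000Hz, and a loaded TTS model's declared input rate takes precedence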
TARGET_SAMPLE_RATE = 22050
if args.tts_backend != "tortoise":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
if os.path.exists(infile):
results = json.load(open(infile, 'r', encoding="utf-8"))
for file in tqdm(files, desc="Iterating through voice files"):
basename = os.path.basename(file)
if basename in results and skip_existings:
print(f"Skipping already parsed file: {basename}")
continue
try:
result = whisper_transcribe(file, language=language)
except Exception as e:
print("Failed to transcribe:", file, e)
continue
results[basename] = result
if not quantize_in_memory:
waveform, sample_rate = torchaudio.load(file)
			# resample to the model's input rate now, since it would get resampled for training anyway
			# this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
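			# crude stereo-to-mono: keep only the first channel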
if waveform.shape[0] == 2:
waveform = waveform[:1]
try:
kwargs = {}
if basename[-4:] == ".wav":
kwargs['encoding'] = "PCM_S"
kwargs['bits_per_sample'] = 16
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs)
except Exception as e:
print(e)
with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
do_gc()
modified = False
for basename in results:
try:
sanitized = whisper_sanitize(results[basename])
if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
results[basename] = sanitized
modified = True
print("Segments sanizited: ", basename)
except Exception as e:
print("Failed to sanitize:", basename, e)
pass
if modified:
os.rename(infile, infile.replace(".json", ".unsanitized.json"))
with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
return f"Processed dataset to: {indir}"
def slice_waveform( waveform, sample_rate, start, end, trim ):
start = int(start * sample_rate)
end = int(end * sample_rate)
	if start < 0:
		start = 0
	# the slice end is exclusive, so clamp to the frame count itself rather than one frame short of it
	if end > waveform.shape[-1]:
		end = waveform.shape[-1]
sliced = waveform[:, start:end]
error = validate_waveform( sliced, sample_rate, min_only=True )
if trim and not error:
sliced = torchaudio.functional.vad( sliced, sample_rate )
return sliced, error
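# slices each transcribed segment out of its source audio according to whisper.json, applying optional start/end offsets and silence trimming, and saves the pieces under ./training/{voice}/audio/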
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None, progress=gr.Progress() ):
indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json'
messages = []
if not os.path.exists(infile):
message = f"Missing dataset: {infile}"
print(message)
return message
if results is None:
results = json.load(open(infile, 'r', encoding="utf-8"))
TARGET_SAMPLE_RATE = 22050
if args.tts_backend != "tortoise":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
files = 0
segments = 0
for filename in results:
path = f'./voices/{voice}/{filename}'
extension = os.path.splitext(filename)[-1][1:]
out_extension = extension # "wav"
if not os.path.exists(path):
path = f'./training/{voice}/{filename}'
if not os.path.exists(path):
message = f"Missing source audio: {filename}"
print(message)
messages.append(message)
continue
files += 1
result = results[filename]
waveform, sample_rate = torchaudio.load(path)
num_channels, num_frames = waveform.shape
duration = num_frames / sample_rate
for segment in result['segments']:
file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
continue
sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
			# downmix to mono on the slice actually being saved, not the source waveform
			if sliced.shape[0] == 2:
				sliced = sliced[:1]
kwargs = {}
if file[-4:] == ".wav":
kwargs['encoding'] = "PCM_S"
kwargs['bits_per_sample'] = 16
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs)
			segments += 1
messages.append(f"Sliced segments: {files} => {segments}.")