forked from mrq/ai-voice-cloning
added VALL-E inference support (very rudimentary, gimped, but it will load a model trained on a config generated through the web UI)
This commit is contained in:
parent
9b01377667
commit
4744120be2
|
@ -1,61 +0,0 @@
|
||||||
{
|
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"params": {
|
|
||||||
"lr": 2e-05,
|
|
||||||
"betas": [
|
|
||||||
0.9,
|
|
||||||
0.96
|
|
||||||
],
|
|
||||||
"eps": 1e-07,
|
|
||||||
"weight_decay": 0.01
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"scheduler":{
|
|
||||||
"type":"WarmupLR",
|
|
||||||
"params":{
|
|
||||||
"warmup_min_lr":0,
|
|
||||||
"warmup_max_lr":2e-5,
|
|
||||||
"warmup_num_steps":100,
|
|
||||||
"warmup_type":"linear"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fp16":{
|
|
||||||
"enabled":true,
|
|
||||||
"loss_scale":0,
|
|
||||||
"loss_scale_window":1000,
|
|
||||||
"initial_scale_power":16,
|
|
||||||
"hysteresis":2,
|
|
||||||
"min_loss_scale":1
|
|
||||||
},
|
|
||||||
"autotuning":{
|
|
||||||
"enabled":false,
|
|
||||||
"results_dir":"./config/autotune/results",
|
|
||||||
"exps_dir":"./config/autotune/exps",
|
|
||||||
"overwrite":false,
|
|
||||||
"metric":"throughput",
|
|
||||||
"start_profile_step":10,
|
|
||||||
"end_profile_step":20,
|
|
||||||
"fast":false,
|
|
||||||
"max_train_batch_size":32,
|
|
||||||
"mp_size":1,
|
|
||||||
"num_tuning_micro_batch_sizes":3,
|
|
||||||
"tuner_type":"model_based",
|
|
||||||
"tuner_early_stopping":5,
|
|
||||||
"tuner_num_trials":50,
|
|
||||||
"arg_mappings":{
|
|
||||||
"train_micro_batch_size_per_gpu":"--per_device_train_batch_size",
|
|
||||||
"gradient_accumulation_steps ":"--gradient_accumulation_steps"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"zero_optimization":{
|
|
||||||
"stage":0,
|
|
||||||
"reduce_bucket_size":"auto",
|
|
||||||
"contiguous_gradients":true,
|
|
||||||
"sub_group_size":1e8,
|
|
||||||
"stage3_prefetch_bucket_size":"auto",
|
|
||||||
"stage3_param_persistence_threshold":"auto",
|
|
||||||
"stage3_max_live_parameters":"auto",
|
|
||||||
"stage3_max_reuse_distance":"auto"
|
|
||||||
}
|
|
||||||
}
|
|
403
src/utils.py
403
src/utils.py
|
@ -22,6 +22,7 @@ import psutil
|
||||||
import yaml
|
import yaml
|
||||||
import hashlib
|
import hashlib
|
||||||
import string
|
import string
|
||||||
|
import random
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
@ -34,7 +35,7 @@ import pandas as pd
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
|
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
|
||||||
|
@ -68,6 +69,10 @@ try:
|
||||||
from vall_e.emb.qnt import encode as valle_quantize
|
from vall_e.emb.qnt import encode as valle_quantize
|
||||||
from vall_e.emb.g2p import encode as valle_phonemize
|
from vall_e.emb.g2p import encode as valle_phonemize
|
||||||
|
|
||||||
|
from vall_e.inference import TTS as VALLE_TTS
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
VALLE_ENABLED = True
|
VALLE_ENABLED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
@ -111,6 +116,12 @@ def resample( waveform, input_rate, output_rate=44100 ):
|
||||||
return RESAMPLERS[key]( waveform ), output_rate
|
return RESAMPLERS[key]( waveform ), output_rate
|
||||||
|
|
||||||
def generate(**kwargs):
|
def generate(**kwargs):
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
|
return generate_tortoise(**kwargs)
|
||||||
|
if args.tts_backend == "vall-e":
|
||||||
|
return generate_valle(**kwargs)
|
||||||
|
|
||||||
|
def generate_valle(**kwargs):
|
||||||
parameters = {}
|
parameters = {}
|
||||||
parameters.update(kwargs)
|
parameters.update(kwargs)
|
||||||
|
|
||||||
|
@ -143,6 +154,297 @@ def generate(**kwargs):
|
||||||
conditioning_latents = None
|
conditioning_latents = None
|
||||||
sample_voice = None
|
sample_voice = None
|
||||||
|
|
||||||
|
voice_cache = {}
|
||||||
|
def fetch_voice( voice ):
|
||||||
|
voice_dir = f'./voices/{voice}/'
|
||||||
|
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
||||||
|
return files
|
||||||
|
# return random.choice(files)
|
||||||
|
|
||||||
|
def get_settings( override=None ):
|
||||||
|
settings = {
|
||||||
|
'ar_temp': float(parameters['temperature']),
|
||||||
|
'nar_temp': float(parameters['temperature']),
|
||||||
|
'max_ar_samples': parameters['num_autoregressive_samples'],
|
||||||
|
}
|
||||||
|
|
||||||
|
# could be better to just do a ternary on everything above, but i am not a professional
|
||||||
|
selected_voice = voice
|
||||||
|
if override is not None:
|
||||||
|
if 'voice' in override:
|
||||||
|
selected_voice = override['voice']
|
||||||
|
|
||||||
|
for k in override:
|
||||||
|
if k not in settings:
|
||||||
|
continue
|
||||||
|
settings[k] = override[k]
|
||||||
|
|
||||||
|
settings['reference'] = fetch_voice(voice=selected_voice)
|
||||||
|
return settings
|
||||||
|
|
||||||
|
if not parameters['delimiter']:
|
||||||
|
parameters['delimiter'] = "\n"
|
||||||
|
elif parameters['delimiter'] == "\\n":
|
||||||
|
parameters['delimiter'] = "\n"
|
||||||
|
|
||||||
|
if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
|
||||||
|
texts = parameters['text'].split(parameters['delimiter'])
|
||||||
|
else:
|
||||||
|
texts = split_and_recombine_text(parameters['text'])
|
||||||
|
|
||||||
|
full_start_time = time.time()
|
||||||
|
|
||||||
|
outdir = f"{args.results_folder}/{voice}/"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
audio_cache = {}
|
||||||
|
|
||||||
|
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
idx_cache = {}
|
||||||
|
for i, file in enumerate(os.listdir(outdir)):
|
||||||
|
filename = os.path.basename(file)
|
||||||
|
extension = os.path.splitext(filename)[1]
|
||||||
|
if extension != ".json" and extension != ".wav":
|
||||||
|
continue
|
||||||
|
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
|
||||||
|
if match and len(match) > 0:
|
||||||
|
key = int(match[0])
|
||||||
|
idx_cache[key] = True
|
||||||
|
|
||||||
|
if len(idx_cache) > 0:
|
||||||
|
keys = sorted(list(idx_cache.keys()))
|
||||||
|
idx = keys[-1] + 1
|
||||||
|
|
||||||
|
idx = pad(idx, 4)
|
||||||
|
|
||||||
|
def get_name(line=0, candidate=0, combined=False):
|
||||||
|
name = f"{idx}"
|
||||||
|
if combined:
|
||||||
|
name = f"{name}_combined"
|
||||||
|
elif len(texts) > 1:
|
||||||
|
name = f"{name}_{line}"
|
||||||
|
if parameters['candidates'] > 1:
|
||||||
|
name = f"{name}_{candidate}"
|
||||||
|
return name
|
||||||
|
|
||||||
|
def get_info( voice, settings = None, latents = True ):
|
||||||
|
info = {}
|
||||||
|
info.update(parameters)
|
||||||
|
|
||||||
|
info['time'] = time.time()-full_start_time
|
||||||
|
info['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
info['progress'] = None
|
||||||
|
del info['progress']
|
||||||
|
|
||||||
|
if info['delimiter'] == "\n":
|
||||||
|
info['delimiter'] = "\\n"
|
||||||
|
|
||||||
|
if settings is not None:
|
||||||
|
for k in settings:
|
||||||
|
if k in info:
|
||||||
|
info[k] = settings[k]
|
||||||
|
return info
|
||||||
|
|
||||||
|
INFERENCING = True
|
||||||
|
for line, cut_text in enumerate(texts):
|
||||||
|
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
|
||||||
|
print(f"{progress.msg_prefix} Generating line: {cut_text}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# do setting editing
|
||||||
|
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
||||||
|
override = None
|
||||||
|
if match and len(match) > 0:
|
||||||
|
match = match[0]
|
||||||
|
try:
|
||||||
|
override = json.loads(match[0])
|
||||||
|
cut_text = match[1].strip()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||||
|
|
||||||
|
settings = get_settings( override=override )
|
||||||
|
reference = settings['reference']
|
||||||
|
settings.pop("reference")
|
||||||
|
|
||||||
|
gen = tts.inference(cut_text, reference, **settings )
|
||||||
|
|
||||||
|
run_time = time.time()-start_time
|
||||||
|
print(f"Generating line took {run_time} seconds")
|
||||||
|
|
||||||
|
if not isinstance(gen, list):
|
||||||
|
gen = [gen]
|
||||||
|
|
||||||
|
for j, g in enumerate(gen):
|
||||||
|
wav, sr = g
|
||||||
|
name = get_name(line=line, candidate=j)
|
||||||
|
|
||||||
|
settings['text'] = cut_text
|
||||||
|
settings['time'] = run_time
|
||||||
|
settings['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# save here in case some error happens mid-batch
|
||||||
|
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
|
||||||
|
soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
|
||||||
|
wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
|
||||||
|
|
||||||
|
audio_cache[name] = {
|
||||||
|
'audio': wav,
|
||||||
|
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
del gen
|
||||||
|
do_gc()
|
||||||
|
INFERENCING = False
|
||||||
|
|
||||||
|
for k in audio_cache:
|
||||||
|
audio = audio_cache[k]['audio']
|
||||||
|
|
||||||
|
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
|
||||||
|
if volume_adjust is not None:
|
||||||
|
audio = volume_adjust(audio)
|
||||||
|
|
||||||
|
audio_cache[k]['audio'] = audio
|
||||||
|
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
|
output_voices = []
|
||||||
|
for candidate in range(parameters['candidates']):
|
||||||
|
if len(texts) > 1:
|
||||||
|
audio_clips = []
|
||||||
|
for line in range(len(texts)):
|
||||||
|
name = get_name(line=line, candidate=candidate)
|
||||||
|
audio = audio_cache[name]['audio']
|
||||||
|
audio_clips.append(audio)
|
||||||
|
|
||||||
|
name = get_name(candidate=candidate, combined=True)
|
||||||
|
audio = torch.cat(audio_clips, dim=-1)
|
||||||
|
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
|
audio = audio.squeeze(0).cpu()
|
||||||
|
audio_cache[name] = {
|
||||||
|
'audio': audio,
|
||||||
|
'settings': get_info(voice=voice),
|
||||||
|
'output': True
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
name = get_name(candidate=candidate)
|
||||||
|
audio_cache[name]['output'] = True
|
||||||
|
|
||||||
|
|
||||||
|
if args.voice_fixer:
|
||||||
|
if not voicefixer:
|
||||||
|
progress(0, "Loading voicefix...")
|
||||||
|
load_voicefixer()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fixed_cache = {}
|
||||||
|
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||||
|
del audio_cache[name]['audio']
|
||||||
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
path = f'{outdir}/{voice}_{name}.wav'
|
||||||
|
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
||||||
|
voicefixer.restore(
|
||||||
|
input=path,
|
||||||
|
output=fixed,
|
||||||
|
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||||
|
#mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
fixed_cache[f'{name}_fixed'] = {
|
||||||
|
'settings': audio_cache[name]['settings'],
|
||||||
|
'output': True
|
||||||
|
}
|
||||||
|
audio_cache[name]['output'] = False
|
||||||
|
|
||||||
|
for name in fixed_cache:
|
||||||
|
audio_cache[name] = fixed_cache[name]
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print("\nFailed to run Voicefixer")
|
||||||
|
|
||||||
|
for name in audio_cache:
|
||||||
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
|
if args.prune_nonfinal_outputs:
|
||||||
|
audio_cache[name]['pruned'] = True
|
||||||
|
os.remove(f'{outdir}/{voice}_{name}.wav')
|
||||||
|
continue
|
||||||
|
|
||||||
|
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||||
|
|
||||||
|
if not args.embed_output_metadata:
|
||||||
|
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||||
|
|
||||||
|
if args.embed_output_metadata:
|
||||||
|
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||||
|
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
||||||
|
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||||
|
metadata.save()
|
||||||
|
|
||||||
|
if sample_voice is not None:
|
||||||
|
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
||||||
|
|
||||||
|
info = get_info(voice=voice, latents=False)
|
||||||
|
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||||
|
|
||||||
|
info['seed'] = usedSeed
|
||||||
|
if 'latents' in info:
|
||||||
|
del info['latents']
|
||||||
|
|
||||||
|
os.makedirs('./config/', exist_ok=True)
|
||||||
|
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(info, indent='\t') )
|
||||||
|
|
||||||
|
stats = [
|
||||||
|
[ parameters['seed'], "{:.3f}".format(info['time']) ]
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
sample_voice,
|
||||||
|
output_voices,
|
||||||
|
stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tortoise(**kwargs):
|
||||||
|
parameters = {}
|
||||||
|
parameters.update(kwargs)
|
||||||
|
|
||||||
|
voice = parameters['voice']
|
||||||
|
progress = parameters['progress'] if 'progress' in parameters else None
|
||||||
|
if parameters['seed'] == 0:
|
||||||
|
parameters['seed'] = None
|
||||||
|
|
||||||
|
usedSeed = parameters['seed']
|
||||||
|
|
||||||
|
global args
|
||||||
|
global tts
|
||||||
|
|
||||||
|
unload_whisper()
|
||||||
|
unload_voicefixer()
|
||||||
|
|
||||||
|
if not tts:
|
||||||
|
# should check if it's loading or unloaded, and load it if it's unloaded
|
||||||
|
if tts_loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
load_tts()
|
||||||
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
do_gc()
|
||||||
|
|
||||||
|
voice_samples = None
|
||||||
|
conditioning_latents = None
|
||||||
|
sample_voice = None
|
||||||
|
|
||||||
voice_cache = {}
|
voice_cache = {}
|
||||||
def fetch_voice( voice ):
|
def fetch_voice( voice ):
|
||||||
cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
|
cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
|
||||||
|
@ -295,11 +597,13 @@ def generate(**kwargs):
|
||||||
def get_info( voice, settings = None, latents = True ):
|
def get_info( voice, settings = None, latents = True ):
|
||||||
info = {}
|
info = {}
|
||||||
info.update(parameters)
|
info.update(parameters)
|
||||||
info['time'] = time.time()-full_start_time
|
|
||||||
|
|
||||||
|
info['time'] = time.time()-full_start_time
|
||||||
info['datetime'] = datetime.now().isoformat()
|
info['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
info['model'] = tts.autoregressive_model_path
|
info['model'] = tts.autoregressive_model_path
|
||||||
info['model_hash'] = tts.autoregressive_model_hash
|
info['model_hash'] = tts.autoregressive_model_hash
|
||||||
|
|
||||||
info['progress'] = None
|
info['progress'] = None
|
||||||
del info['progress']
|
del info['progress']
|
||||||
|
|
||||||
|
@ -381,7 +685,8 @@ def generate(**kwargs):
|
||||||
|
|
||||||
settings['text'] = cut_text
|
settings['text'] = cut_text
|
||||||
settings['time'] = run_time
|
settings['time'] = run_time
|
||||||
settings['datetime'] = datetime.now().isoformat(),
|
settings['datetime'] = datetime.now().isoformat()
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
settings['model'] = tts.autoregressive_model_path
|
settings['model'] = tts.autoregressive_model_path
|
||||||
settings['model_hash'] = tts.autoregressive_model_hash
|
settings['model_hash'] = tts.autoregressive_model_hash
|
||||||
|
|
||||||
|
@ -745,8 +1050,8 @@ class TrainingState():
|
||||||
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
||||||
self.it_rates += it_rate
|
self.it_rates += it_rate
|
||||||
|
|
||||||
|
if self.it_rates > 0 and self.it * self.steps > 0:
|
||||||
epoch_rate = self.it_rates / self.it * self.steps
|
epoch_rate = self.it_rates / self.it * self.steps
|
||||||
if epoch_rate > 0:
|
|
||||||
self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'
|
self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -925,6 +1230,7 @@ class TrainingState():
|
||||||
self.it_rates = 0
|
self.it_rates = 0
|
||||||
|
|
||||||
unq = {}
|
unq = {}
|
||||||
|
averager = None
|
||||||
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
with open(log, 'r', encoding="utf-8") as f:
|
with open(log, 'r', encoding="utf-8") as f:
|
||||||
|
@ -941,16 +1247,18 @@ class TrainingState():
|
||||||
if line.find('Training Metrics:') >= 0:
|
if line.find('Training Metrics:') >= 0:
|
||||||
split = line.split("Training Metrics:")[-1]
|
split = line.split("Training Metrics:")[-1]
|
||||||
data = json.loads(split)
|
data = json.loads(split)
|
||||||
data['mode'] = "training"
|
|
||||||
name = "train"
|
name = "train"
|
||||||
|
mode = "training"
|
||||||
elif line.find('Validation Metrics:') >= 0:
|
elif line.find('Validation Metrics:') >= 0:
|
||||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||||
data['mode'] = "validation"
|
|
||||||
if "it" not in data:
|
if "it" not in data:
|
||||||
data['it'] = it
|
data['it'] = it
|
||||||
if "epoch" not in data:
|
if "epoch" not in data:
|
||||||
data['epoch'] = epoch
|
data['epoch'] = epoch
|
||||||
|
|
||||||
name = data['name'] if 'name' in data else "val"
|
name = data['name'] if 'name' in data else "val"
|
||||||
|
mode = "validation"
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -960,14 +1268,39 @@ class TrainingState():
|
||||||
it = data['it']
|
it = data['it']
|
||||||
epoch = data['epoch']
|
epoch = data['epoch']
|
||||||
|
|
||||||
# this method should have it at least
|
if args.tts_backend == "vall-e":
|
||||||
unq[f'{it}_{name}'] = data
|
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
|
||||||
|
averager = {
|
||||||
|
'key': f'{it}_{name}',
|
||||||
|
'mode': mode,
|
||||||
|
"metrics": {}
|
||||||
|
}
|
||||||
|
for k in data:
|
||||||
|
if data[k] is None:
|
||||||
|
continue
|
||||||
|
averager['metrics'][k] = [ data[k] ]
|
||||||
|
else:
|
||||||
|
for k in data:
|
||||||
|
if data[k] is None:
|
||||||
|
continue
|
||||||
|
averager['metrics'][k].append( data[k] )
|
||||||
|
|
||||||
|
unq[f'{it}_{mode}_{name}'] = averager
|
||||||
|
else:
|
||||||
|
unq[f'{it}_{mode}_{name}'] = data
|
||||||
|
|
||||||
if update and it <= self.last_info_check_at:
|
if update and it <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for it in unq:
|
for it in unq:
|
||||||
self.parse_metrics(unq[it])
|
if args.tts_backend == "vall-e":
|
||||||
|
stats = unq[it]
|
||||||
|
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
|
||||||
|
data['mode'] = stats
|
||||||
|
data['steps'] = len(stats['metrics']['it'])
|
||||||
|
else:
|
||||||
|
data = unq[it]
|
||||||
|
self.parse_metrics(data)
|
||||||
|
|
||||||
self.last_info_check_at = highest_step
|
self.last_info_check_at = highest_step
|
||||||
|
|
||||||
|
@ -1087,6 +1420,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
||||||
|
|
||||||
|
|
||||||
# ensure we have the dvae.pth
|
# ensure we have the dvae.pth
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
get_model_path('dvae.pth')
|
get_model_path('dvae.pth')
|
||||||
|
|
||||||
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
||||||
|
@ -2086,6 +2420,8 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
||||||
res = res + defaults
|
res = res + defaults
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def get_valle_models(dir="./training/"):
|
||||||
|
return [ f'{dir}/{d}/config.yaml' for d in os.listdir(dir) if os.path.exists(f'{dir}/{d}/config.yaml') ]
|
||||||
|
|
||||||
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
|
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
@ -2269,6 +2605,8 @@ def setup_args():
|
||||||
|
|
||||||
'phonemizer-backend': 'espeak',
|
'phonemizer-backend': 'espeak',
|
||||||
|
|
||||||
|
'valle-model': None,
|
||||||
|
|
||||||
'whisper-backend': 'openai/whisper',
|
'whisper-backend': 'openai/whisper',
|
||||||
'whisper-model': "base",
|
'whisper-model': "base",
|
||||||
'whisper-batchsize': 1,
|
'whisper-batchsize': 1,
|
||||||
|
@ -2319,6 +2657,8 @@ def setup_args():
|
||||||
|
|
||||||
parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")
|
parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")
|
||||||
|
|
||||||
|
parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")
|
||||||
|
|
||||||
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
|
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
|
||||||
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
|
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
|
||||||
parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
|
parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
|
||||||
|
@ -2390,6 +2730,8 @@ def get_default_settings( hypenated=True ):
|
||||||
|
|
||||||
'phonemizer-backend': args.phonemizer_backend,
|
'phonemizer-backend': args.phonemizer_backend,
|
||||||
|
|
||||||
|
'valle-model': args.valle_model,
|
||||||
|
|
||||||
'whisper-backend': args.whisper_backend,
|
'whisper-backend': args.whisper_backend,
|
||||||
'whisper-model': args.whisper_model,
|
'whisper-model': args.whisper_model,
|
||||||
'whisper-batchsize': args.whisper_batchsize,
|
'whisper-batchsize': args.whisper_batchsize,
|
||||||
|
@ -2440,6 +2782,8 @@ def update_args( **kwargs ):
|
||||||
|
|
||||||
args.phonemizer_backend = settings['phonemizer_backend']
|
args.phonemizer_backend = settings['phonemizer_backend']
|
||||||
|
|
||||||
|
args.valle_model = settings['valle_model']
|
||||||
|
|
||||||
args.whisper_backend = settings['whisper_backend']
|
args.whisper_backend = settings['whisper_backend']
|
||||||
args.whisper_model = settings['whisper_model']
|
args.whisper_model = settings['whisper_model']
|
||||||
args.whisper_batchsize = settings['whisper_batchsize']
|
args.whisper_batchsize = settings['whisper_batchsize']
|
||||||
|
@ -2553,13 +2897,20 @@ def version_check_tts( min_version ):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None ):
|
def load_tts( restart=False,
|
||||||
|
# TorToiSe configs
|
||||||
|
autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
|
||||||
|
# VALL-E configs
|
||||||
|
valle_model=None,
|
||||||
|
):
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
if restart:
|
if restart:
|
||||||
unload_tts()
|
unload_tts()
|
||||||
|
|
||||||
|
tts_loading = True
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
if autoregressive_model:
|
if autoregressive_model:
|
||||||
args.autoregressive_model = autoregressive_model
|
args.autoregressive_model = autoregressive_model
|
||||||
else:
|
else:
|
||||||
|
@ -2586,17 +2937,21 @@ def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vo
|
||||||
if get_device_name() == "cpu":
|
if get_device_name() == "cpu":
|
||||||
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
|
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
|
||||||
|
|
||||||
tts_loading = True
|
print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
|
||||||
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {vocoder_model})")
|
tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
|
elif args.tts_backend == "vall-e":
|
||||||
|
if valle_model:
|
||||||
|
args.valle_model = valle_model
|
||||||
|
else:
|
||||||
|
valle_model = args.valle_model
|
||||||
|
|
||||||
|
print(f"Loading VALL-E... (Config: {valle_model})")
|
||||||
|
tts = VALLE_TTS(config=args.valle_model)
|
||||||
|
|
||||||
|
print("Loaded TTS, ready for generation.")
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
|
||||||
get_model_path('dvae.pth')
|
|
||||||
print("Loaded TorToiSe, ready for generation.")
|
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
setup_tortoise = load_tts
|
|
||||||
|
|
||||||
def unload_tts():
|
def unload_tts():
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
|
@ -2643,6 +2998,9 @@ def deduce_autoregressive_model(voice=None):
|
||||||
return get_model_path('autoregressive.pth')
|
return get_model_path('autoregressive.pth')
|
||||||
|
|
||||||
def update_autoregressive_model(autoregressive_model_path):
|
def update_autoregressive_model(autoregressive_model_path):
|
||||||
|
if args.tts_backend != "tortoise":
|
||||||
|
raise f"Unsupported backend: {args.tts_backend}"
|
||||||
|
|
||||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
||||||
if match:
|
if match:
|
||||||
autoregressive_model_path = match[0]
|
autoregressive_model_path = match[0]
|
||||||
|
@ -2677,6 +3035,9 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
return autoregressive_model_path
|
return autoregressive_model_path
|
||||||
|
|
||||||
def update_diffusion_model(diffusion_model_path):
|
def update_diffusion_model(diffusion_model_path):
|
||||||
|
if args.tts_backend != "tortoise":
|
||||||
|
raise f"Unsupported backend: {args.tts_backend}"
|
||||||
|
|
||||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
|
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
|
||||||
if match:
|
if match:
|
||||||
diffusion_model_path = match[0]
|
diffusion_model_path = match[0]
|
||||||
|
@ -2711,6 +3072,9 @@ def update_diffusion_model(diffusion_model_path):
|
||||||
return diffusion_model_path
|
return diffusion_model_path
|
||||||
|
|
||||||
def update_vocoder_model(vocoder_model):
|
def update_vocoder_model(vocoder_model):
|
||||||
|
if args.tts_backend != "tortoise":
|
||||||
|
raise f"Unsupported backend: {args.tts_backend}"
|
||||||
|
|
||||||
args.vocoder_model = vocoder_model
|
args.vocoder_model = vocoder_model
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
print(f'Stored vocoder model to settings: {vocoder_model}')
|
print(f'Stored vocoder model to settings: {vocoder_model}')
|
||||||
|
@ -2733,6 +3097,9 @@ def update_vocoder_model(vocoder_model):
|
||||||
return vocoder_model
|
return vocoder_model
|
||||||
|
|
||||||
def update_tokenizer(tokenizer_json):
|
def update_tokenizer(tokenizer_json):
|
||||||
|
if args.tts_backend != "tortoise":
|
||||||
|
raise f"Unsupported backend: {args.tts_backend}"
|
||||||
|
|
||||||
args.tokenizer_json = tokenizer_json
|
args.tokenizer_json = tokenizer_json
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
print(f'Stored tokenizer to settings: {tokenizer_json}')
|
print(f'Stored tokenizer to settings: {tokenizer_json}')
|
||||||
|
|
18
src/webui.py
18
src/webui.py
|
@ -315,6 +315,8 @@ def setup_gradio():
|
||||||
voice_list = get_voice_list()
|
voice_list = get_voice_list()
|
||||||
result_voices = get_voice_list(args.results_folder)
|
result_voices = get_voice_list(args.results_folder)
|
||||||
|
|
||||||
|
valle_models = get_valle_models()
|
||||||
|
|
||||||
autoregressive_models = get_autoregressive_models()
|
autoregressive_models = get_autoregressive_models()
|
||||||
diffusion_models = get_diffusion_models()
|
diffusion_models = get_diffusion_models()
|
||||||
tokenizer_jsons = get_tokenizer_jsons()
|
tokenizer_jsons = get_tokenizer_jsons()
|
||||||
|
@ -337,11 +339,11 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
||||||
|
|
||||||
GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True )
|
GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True, visible=args.tts_backend=="tortoise" )
|
||||||
GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False)
|
GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False)
|
||||||
GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||||
GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
|
GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
|
||||||
GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0)
|
GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0, visible=args.tts_backend=="tortoise")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||||
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
||||||
|
@ -357,17 +359,17 @@ def setup_gradio():
|
||||||
outputs=GENERATE_SETTINGS["mic_audio"],
|
outputs=GENERATE_SETTINGS["mic_audio"],
|
||||||
)
|
)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
|
||||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
||||||
|
|
||||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
||||||
|
|
||||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
||||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
|
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
|
||||||
|
|
||||||
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
|
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
|
||||||
|
|
||||||
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
|
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings", visible=args.tts_backend=="tortoise")
|
||||||
reset_generate_settings_button = gr.Button(value="Reset to Default")
|
reset_generate_settings_button = gr.Button(value="Reset to Default")
|
||||||
with gr.Column(visible=False) as col:
|
with gr.Column(visible=False) as col:
|
||||||
experimental_column = col
|
experimental_column = col
|
||||||
|
@ -606,10 +608,12 @@ def setup_gradio():
|
||||||
EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override)
|
EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override)
|
||||||
|
|
||||||
EXEC_SETTINGS['results_folder'] = gr.Textbox(label="Results Folder", value=args.results_folder)
|
EXEC_SETTINGS['results_folder'] = gr.Textbox(label="Results Folder", value=args.results_folder)
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
|
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
|
||||||
|
|
||||||
|
with gr.Column(visible=args.tts_backend=="vall-e"):
|
||||||
|
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
|
||||||
|
|
||||||
|
with gr.Column(visible=args.tts_backend=="tortoise"):
|
||||||
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
|
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
|
||||||
EXEC_SETTINGS['diffusion_model'] = gr.Dropdown(choices=diffusion_models, label="Diffusion Model", value=args.diffusion_model if args.diffusion_model else diffusion_models[0])
|
EXEC_SETTINGS['diffusion_model'] = gr.Dropdown(choices=diffusion_models, label="Diffusion Model", value=args.diffusion_model if args.diffusion_model else diffusion_models[0])
|
||||||
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
|
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user