forked from mrq/ai-voice-cloning
added VALL-E inference support (very rudimentary, gimped, but it will load a model trained on a config generated through the web UI)
This commit is contained in:
parent
9b01377667
commit
4744120be2
|
@ -1,61 +0,0 @@
|
|||
{
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 2e-05,
|
||||
"betas": [
|
||||
0.9,
|
||||
0.96
|
||||
],
|
||||
"eps": 1e-07,
|
||||
"weight_decay": 0.01
|
||||
}
|
||||
},
|
||||
"scheduler":{
|
||||
"type":"WarmupLR",
|
||||
"params":{
|
||||
"warmup_min_lr":0,
|
||||
"warmup_max_lr":2e-5,
|
||||
"warmup_num_steps":100,
|
||||
"warmup_type":"linear"
|
||||
}
|
||||
},
|
||||
"fp16":{
|
||||
"enabled":true,
|
||||
"loss_scale":0,
|
||||
"loss_scale_window":1000,
|
||||
"initial_scale_power":16,
|
||||
"hysteresis":2,
|
||||
"min_loss_scale":1
|
||||
},
|
||||
"autotuning":{
|
||||
"enabled":false,
|
||||
"results_dir":"./config/autotune/results",
|
||||
"exps_dir":"./config/autotune/exps",
|
||||
"overwrite":false,
|
||||
"metric":"throughput",
|
||||
"start_profile_step":10,
|
||||
"end_profile_step":20,
|
||||
"fast":false,
|
||||
"max_train_batch_size":32,
|
||||
"mp_size":1,
|
||||
"num_tuning_micro_batch_sizes":3,
|
||||
"tuner_type":"model_based",
|
||||
"tuner_early_stopping":5,
|
||||
"tuner_num_trials":50,
|
||||
"arg_mappings":{
|
||||
"train_micro_batch_size_per_gpu":"--per_device_train_batch_size",
|
||||
"gradient_accumulation_steps ":"--gradient_accumulation_steps"
|
||||
}
|
||||
},
|
||||
"zero_optimization":{
|
||||
"stage":0,
|
||||
"reduce_bucket_size":"auto",
|
||||
"contiguous_gradients":true,
|
||||
"sub_group_size":1e8,
|
||||
"stage3_prefetch_bucket_size":"auto",
|
||||
"stage3_param_persistence_threshold":"auto",
|
||||
"stage3_max_live_parameters":"auto",
|
||||
"stage3_max_reuse_distance":"auto"
|
||||
}
|
||||
}
|
405
src/utils.py
405
src/utils.py
|
@ -22,6 +22,7 @@ import psutil
|
|||
import yaml
|
||||
import hashlib
|
||||
import string
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
@ -34,7 +35,7 @@ import pandas as pd
|
|||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
||||
from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate
|
||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
|
||||
from tortoise.utils.text import split_and_recombine_text
|
||||
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
|
||||
|
@ -68,6 +69,10 @@ try:
|
|||
from vall_e.emb.qnt import encode as valle_quantize
|
||||
from vall_e.emb.g2p import encode as valle_phonemize
|
||||
|
||||
from vall_e.inference import TTS as VALLE_TTS
|
||||
|
||||
import soundfile
|
||||
|
||||
VALLE_ENABLED = True
|
||||
except Exception as e:
|
||||
pass
|
||||
|
@ -111,6 +116,12 @@ def resample( waveform, input_rate, output_rate=44100 ):
|
|||
return RESAMPLERS[key]( waveform ), output_rate
|
||||
|
||||
def generate(**kwargs):
|
||||
if args.tts_backend == "tortoise":
|
||||
return generate_tortoise(**kwargs)
|
||||
if args.tts_backend == "vall-e":
|
||||
return generate_valle(**kwargs)
|
||||
|
||||
def generate_valle(**kwargs):
|
||||
parameters = {}
|
||||
parameters.update(kwargs)
|
||||
|
||||
|
@ -140,7 +151,298 @@ def generate(**kwargs):
|
|||
do_gc()
|
||||
|
||||
voice_samples = None
|
||||
conditioning_latents =None
|
||||
conditioning_latents = None
|
||||
sample_voice = None
|
||||
|
||||
voice_cache = {}
|
||||
def fetch_voice( voice ):
|
||||
voice_dir = f'./voices/{voice}/'
|
||||
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
||||
return files
|
||||
# return random.choice(files)
|
||||
|
||||
def get_settings( override=None ):
|
||||
settings = {
|
||||
'ar_temp': float(parameters['temperature']),
|
||||
'nar_temp': float(parameters['temperature']),
|
||||
'max_ar_samples': parameters['num_autoregressive_samples'],
|
||||
}
|
||||
|
||||
# could be better to just do a ternary on everything above, but i am not a professional
|
||||
selected_voice = voice
|
||||
if override is not None:
|
||||
if 'voice' in override:
|
||||
selected_voice = override['voice']
|
||||
|
||||
for k in override:
|
||||
if k not in settings:
|
||||
continue
|
||||
settings[k] = override[k]
|
||||
|
||||
settings['reference'] = fetch_voice(voice=selected_voice)
|
||||
return settings
|
||||
|
||||
if not parameters['delimiter']:
|
||||
parameters['delimiter'] = "\n"
|
||||
elif parameters['delimiter'] == "\\n":
|
||||
parameters['delimiter'] = "\n"
|
||||
|
||||
if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
|
||||
texts = parameters['text'].split(parameters['delimiter'])
|
||||
else:
|
||||
texts = split_and_recombine_text(parameters['text'])
|
||||
|
||||
full_start_time = time.time()
|
||||
|
||||
outdir = f"{args.results_folder}/{voice}/"
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
audio_cache = {}
|
||||
|
||||
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||
|
||||
idx = 0
|
||||
idx_cache = {}
|
||||
for i, file in enumerate(os.listdir(outdir)):
|
||||
filename = os.path.basename(file)
|
||||
extension = os.path.splitext(filename)[1]
|
||||
if extension != ".json" and extension != ".wav":
|
||||
continue
|
||||
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
|
||||
if match and len(match) > 0:
|
||||
key = int(match[0])
|
||||
idx_cache[key] = True
|
||||
|
||||
if len(idx_cache) > 0:
|
||||
keys = sorted(list(idx_cache.keys()))
|
||||
idx = keys[-1] + 1
|
||||
|
||||
idx = pad(idx, 4)
|
||||
|
||||
def get_name(line=0, candidate=0, combined=False):
|
||||
name = f"{idx}"
|
||||
if combined:
|
||||
name = f"{name}_combined"
|
||||
elif len(texts) > 1:
|
||||
name = f"{name}_{line}"
|
||||
if parameters['candidates'] > 1:
|
||||
name = f"{name}_{candidate}"
|
||||
return name
|
||||
|
||||
def get_info( voice, settings = None, latents = True ):
|
||||
info = {}
|
||||
info.update(parameters)
|
||||
|
||||
info['time'] = time.time()-full_start_time
|
||||
info['datetime'] = datetime.now().isoformat()
|
||||
|
||||
info['progress'] = None
|
||||
del info['progress']
|
||||
|
||||
if info['delimiter'] == "\n":
|
||||
info['delimiter'] = "\\n"
|
||||
|
||||
if settings is not None:
|
||||
for k in settings:
|
||||
if k in info:
|
||||
info[k] = settings[k]
|
||||
return info
|
||||
|
||||
INFERENCING = True
|
||||
for line, cut_text in enumerate(texts):
|
||||
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
|
||||
print(f"{progress.msg_prefix} Generating line: {cut_text}")
|
||||
start_time = time.time()
|
||||
|
||||
# do setting editing
|
||||
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
||||
override = None
|
||||
if match and len(match) > 0:
|
||||
match = match[0]
|
||||
try:
|
||||
override = json.loads(match[0])
|
||||
cut_text = match[1].strip()
|
||||
except Exception as e:
|
||||
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||
|
||||
settings = get_settings( override=override )
|
||||
reference = settings['reference']
|
||||
settings.pop("reference")
|
||||
|
||||
gen = tts.inference(cut_text, reference, **settings )
|
||||
|
||||
run_time = time.time()-start_time
|
||||
print(f"Generating line took {run_time} seconds")
|
||||
|
||||
if not isinstance(gen, list):
|
||||
gen = [gen]
|
||||
|
||||
for j, g in enumerate(gen):
|
||||
wav, sr = g
|
||||
name = get_name(line=line, candidate=j)
|
||||
|
||||
settings['text'] = cut_text
|
||||
settings['time'] = run_time
|
||||
settings['datetime'] = datetime.now().isoformat()
|
||||
|
||||
# save here in case some error happens mid-batch
|
||||
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
|
||||
soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
|
||||
wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
|
||||
|
||||
audio_cache[name] = {
|
||||
'audio': wav,
|
||||
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||
}
|
||||
|
||||
del gen
|
||||
do_gc()
|
||||
INFERENCING = False
|
||||
|
||||
for k in audio_cache:
|
||||
audio = audio_cache[k]['audio']
|
||||
|
||||
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
|
||||
if volume_adjust is not None:
|
||||
audio = volume_adjust(audio)
|
||||
|
||||
audio_cache[k]['audio'] = audio
|
||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||
|
||||
output_voices = []
|
||||
for candidate in range(parameters['candidates']):
|
||||
if len(texts) > 1:
|
||||
audio_clips = []
|
||||
for line in range(len(texts)):
|
||||
name = get_name(line=line, candidate=candidate)
|
||||
audio = audio_cache[name]['audio']
|
||||
audio_clips.append(audio)
|
||||
|
||||
name = get_name(candidate=candidate, combined=True)
|
||||
audio = torch.cat(audio_clips, dim=-1)
|
||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
||||
|
||||
audio = audio.squeeze(0).cpu()
|
||||
audio_cache[name] = {
|
||||
'audio': audio,
|
||||
'settings': get_info(voice=voice),
|
||||
'output': True
|
||||
}
|
||||
else:
|
||||
name = get_name(candidate=candidate)
|
||||
audio_cache[name]['output'] = True
|
||||
|
||||
|
||||
if args.voice_fixer:
|
||||
if not voicefixer:
|
||||
progress(0, "Loading voicefix...")
|
||||
load_voicefixer()
|
||||
|
||||
try:
|
||||
fixed_cache = {}
|
||||
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||
del audio_cache[name]['audio']
|
||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
continue
|
||||
|
||||
path = f'{outdir}/{voice}_{name}.wav'
|
||||
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
||||
voicefixer.restore(
|
||||
input=path,
|
||||
output=fixed,
|
||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||
#mode=mode,
|
||||
)
|
||||
|
||||
fixed_cache[f'{name}_fixed'] = {
|
||||
'settings': audio_cache[name]['settings'],
|
||||
'output': True
|
||||
}
|
||||
audio_cache[name]['output'] = False
|
||||
|
||||
for name in fixed_cache:
|
||||
audio_cache[name] = fixed_cache[name]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("\nFailed to run Voicefixer")
|
||||
|
||||
for name in audio_cache:
|
||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
if args.prune_nonfinal_outputs:
|
||||
audio_cache[name]['pruned'] = True
|
||||
os.remove(f'{outdir}/{voice}_{name}.wav')
|
||||
continue
|
||||
|
||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||
|
||||
if not args.embed_output_metadata:
|
||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||
|
||||
if args.embed_output_metadata:
|
||||
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||
continue
|
||||
|
||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
||||
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||
metadata.save()
|
||||
|
||||
if sample_voice is not None:
|
||||
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
||||
|
||||
info = get_info(voice=voice, latents=False)
|
||||
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||
|
||||
info['seed'] = usedSeed
|
||||
if 'latents' in info:
|
||||
del info['latents']
|
||||
|
||||
os.makedirs('./config/', exist_ok=True)
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(info, indent='\t') )
|
||||
|
||||
stats = [
|
||||
[ parameters['seed'], "{:.3f}".format(info['time']) ]
|
||||
]
|
||||
|
||||
return (
|
||||
sample_voice,
|
||||
output_voices,
|
||||
stats,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def generate_tortoise(**kwargs):
|
||||
parameters = {}
|
||||
parameters.update(kwargs)
|
||||
|
||||
voice = parameters['voice']
|
||||
progress = parameters['progress'] if 'progress' in parameters else None
|
||||
if parameters['seed'] == 0:
|
||||
parameters['seed'] = None
|
||||
|
||||
usedSeed = parameters['seed']
|
||||
|
||||
global args
|
||||
global tts
|
||||
|
||||
unload_whisper()
|
||||
unload_voicefixer()
|
||||
|
||||
if not tts:
|
||||
# should check if it's loading or unloaded, and load it if it's unloaded
|
||||
if tts_loading:
|
||||
raise Exception("TTS is still initializing...")
|
||||
load_tts()
|
||||
if hasattr(tts, "loading") and tts.loading:
|
||||
raise Exception("TTS is still initializing...")
|
||||
|
||||
do_gc()
|
||||
|
||||
voice_samples = None
|
||||
conditioning_latents = None
|
||||
sample_voice = None
|
||||
|
||||
voice_cache = {}
|
||||
|
@ -295,11 +597,13 @@ def generate(**kwargs):
|
|||
def get_info( voice, settings = None, latents = True ):
|
||||
info = {}
|
||||
info.update(parameters)
|
||||
info['time'] = time.time()-full_start_time
|
||||
|
||||
info['time'] = time.time()-full_start_time
|
||||
info['datetime'] = datetime.now().isoformat()
|
||||
|
||||
info['model'] = tts.autoregressive_model_path
|
||||
info['model_hash'] = tts.autoregressive_model_hash
|
||||
|
||||
info['progress'] = None
|
||||
del info['progress']
|
||||
|
||||
|
@ -381,7 +685,8 @@ def generate(**kwargs):
|
|||
|
||||
settings['text'] = cut_text
|
||||
settings['time'] = run_time
|
||||
settings['datetime'] = datetime.now().isoformat(),
|
||||
settings['datetime'] = datetime.now().isoformat()
|
||||
if args.tts_backend == "tortoise":
|
||||
settings['model'] = tts.autoregressive_model_path
|
||||
settings['model_hash'] = tts.autoregressive_model_hash
|
||||
|
||||
|
@ -745,8 +1050,8 @@ class TrainingState():
|
|||
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
||||
self.it_rates += it_rate
|
||||
|
||||
if self.it_rates > 0 and self.it * self.steps > 0:
|
||||
epoch_rate = self.it_rates / self.it * self.steps
|
||||
if epoch_rate > 0:
|
||||
self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'
|
||||
|
||||
try:
|
||||
|
@ -925,6 +1230,7 @@ class TrainingState():
|
|||
self.it_rates = 0
|
||||
|
||||
unq = {}
|
||||
averager = None
|
||||
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
|
@ -941,16 +1247,18 @@ class TrainingState():
|
|||
if line.find('Training Metrics:') >= 0:
|
||||
split = line.split("Training Metrics:")[-1]
|
||||
data = json.loads(split)
|
||||
data['mode'] = "training"
|
||||
|
||||
name = "train"
|
||||
mode = "training"
|
||||
elif line.find('Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||
data['mode'] = "validation"
|
||||
if "it" not in data:
|
||||
data['it'] = it
|
||||
if "epoch" not in data:
|
||||
data['epoch'] = epoch
|
||||
|
||||
name = data['name'] if 'name' in data else "val"
|
||||
mode = "validation"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -960,14 +1268,39 @@ class TrainingState():
|
|||
it = data['it']
|
||||
epoch = data['epoch']
|
||||
|
||||
# this method should have it at least
|
||||
unq[f'{it}_{name}'] = data
|
||||
if args.tts_backend == "vall-e":
|
||||
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
|
||||
averager = {
|
||||
'key': f'{it}_{name}',
|
||||
'mode': mode,
|
||||
"metrics": {}
|
||||
}
|
||||
for k in data:
|
||||
if data[k] is None:
|
||||
continue
|
||||
averager['metrics'][k] = [ data[k] ]
|
||||
else:
|
||||
for k in data:
|
||||
if data[k] is None:
|
||||
continue
|
||||
averager['metrics'][k].append( data[k] )
|
||||
|
||||
unq[f'{it}_{mode}_{name}'] = averager
|
||||
else:
|
||||
unq[f'{it}_{mode}_{name}'] = data
|
||||
|
||||
if update and it <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
for it in unq:
|
||||
self.parse_metrics(unq[it])
|
||||
if args.tts_backend == "vall-e":
|
||||
stats = unq[it]
|
||||
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
|
||||
data['mode'] = stats
|
||||
data['steps'] = len(stats['metrics']['it'])
|
||||
else:
|
||||
data = unq[it]
|
||||
self.parse_metrics(data)
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
||||
|
@ -1087,6 +1420,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
|||
|
||||
|
||||
# ensure we have the dvae.pth
|
||||
if args.tts_backend == "tortoise":
|
||||
get_model_path('dvae.pth')
|
||||
|
||||
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
||||
|
@ -2086,6 +2420,8 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
|||
res = res + defaults
|
||||
return res
|
||||
|
||||
def get_valle_models(dir="./training/"):
|
||||
return [ f'{dir}/{d}/config.yaml' for d in os.listdir(dir) if os.path.exists(f'{dir}/{d}/config.yaml') ]
|
||||
|
||||
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
@ -2269,6 +2605,8 @@ def setup_args():
|
|||
|
||||
'phonemizer-backend': 'espeak',
|
||||
|
||||
'valle-model': None,
|
||||
|
||||
'whisper-backend': 'openai/whisper',
|
||||
'whisper-model': "base",
|
||||
'whisper-batchsize': 1,
|
||||
|
@ -2319,6 +2657,8 @@ def setup_args():
|
|||
|
||||
parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")
|
||||
|
||||
parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")
|
||||
|
||||
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
|
||||
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
|
||||
parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
|
||||
|
@ -2390,6 +2730,8 @@ def get_default_settings( hypenated=True ):
|
|||
|
||||
'phonemizer-backend': args.phonemizer_backend,
|
||||
|
||||
'valle-model': args.valle_model,
|
||||
|
||||
'whisper-backend': args.whisper_backend,
|
||||
'whisper-model': args.whisper_model,
|
||||
'whisper-batchsize': args.whisper_batchsize,
|
||||
|
@ -2440,6 +2782,8 @@ def update_args( **kwargs ):
|
|||
|
||||
args.phonemizer_backend = settings['phonemizer_backend']
|
||||
|
||||
args.valle_model = settings['valle_model']
|
||||
|
||||
args.whisper_backend = settings['whisper_backend']
|
||||
args.whisper_model = settings['whisper_model']
|
||||
args.whisper_batchsize = settings['whisper_batchsize']
|
||||
|
@ -2553,13 +2897,20 @@ def version_check_tts( min_version ):
|
|||
return True
|
||||
return False
|
||||
|
||||
def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None ):
|
||||
def load_tts( restart=False,
|
||||
# TorToiSe configs
|
||||
autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
|
||||
# VALL-E configs
|
||||
valle_model=None,
|
||||
):
|
||||
global args
|
||||
global tts
|
||||
|
||||
if restart:
|
||||
unload_tts()
|
||||
|
||||
tts_loading = True
|
||||
if args.tts_backend == "tortoise":
|
||||
if autoregressive_model:
|
||||
args.autoregressive_model = autoregressive_model
|
||||
else:
|
||||
|
@ -2586,17 +2937,21 @@ def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vo
|
|||
if get_device_name() == "cpu":
|
||||
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
|
||||
|
||||
tts_loading = True
|
||||
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {vocoder_model})")
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
|
||||
print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
|
||||
tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
|
||||
elif args.tts_backend == "vall-e":
|
||||
if valle_model:
|
||||
args.valle_model = valle_model
|
||||
else:
|
||||
valle_model = args.valle_model
|
||||
|
||||
print(f"Loading VALL-E... (Config: {valle_model})")
|
||||
tts = VALLE_TTS(config=args.valle_model)
|
||||
|
||||
print("Loaded TTS, ready for generation.")
|
||||
tts_loading = False
|
||||
|
||||
get_model_path('dvae.pth')
|
||||
print("Loaded TorToiSe, ready for generation.")
|
||||
return tts
|
||||
|
||||
setup_tortoise = load_tts
|
||||
|
||||
def unload_tts():
|
||||
global tts
|
||||
|
||||
|
@ -2643,6 +2998,9 @@ def deduce_autoregressive_model(voice=None):
|
|||
return get_model_path('autoregressive.pth')
|
||||
|
||||
def update_autoregressive_model(autoregressive_model_path):
|
||||
if args.tts_backend != "tortoise":
|
||||
raise f"Unsupported backend: {args.tts_backend}"
|
||||
|
||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
||||
if match:
|
||||
autoregressive_model_path = match[0]
|
||||
|
@ -2677,6 +3035,9 @@ def update_autoregressive_model(autoregressive_model_path):
|
|||
return autoregressive_model_path
|
||||
|
||||
def update_diffusion_model(diffusion_model_path):
|
||||
if args.tts_backend != "tortoise":
|
||||
raise f"Unsupported backend: {args.tts_backend}"
|
||||
|
||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
|
||||
if match:
|
||||
diffusion_model_path = match[0]
|
||||
|
@ -2711,6 +3072,9 @@ def update_diffusion_model(diffusion_model_path):
|
|||
return diffusion_model_path
|
||||
|
||||
def update_vocoder_model(vocoder_model):
|
||||
if args.tts_backend != "tortoise":
|
||||
raise f"Unsupported backend: {args.tts_backend}"
|
||||
|
||||
args.vocoder_model = vocoder_model
|
||||
save_args_settings()
|
||||
print(f'Stored vocoder model to settings: {vocoder_model}')
|
||||
|
@ -2733,6 +3097,9 @@ def update_vocoder_model(vocoder_model):
|
|||
return vocoder_model
|
||||
|
||||
def update_tokenizer(tokenizer_json):
|
||||
if args.tts_backend != "tortoise":
|
||||
raise f"Unsupported backend: {args.tts_backend}"
|
||||
|
||||
args.tokenizer_json = tokenizer_json
|
||||
save_args_settings()
|
||||
print(f'Stored tokenizer to settings: {tokenizer_json}')
|
||||
|
|
18
src/webui.py
18
src/webui.py
|
@ -315,6 +315,8 @@ def setup_gradio():
|
|||
voice_list = get_voice_list()
|
||||
result_voices = get_voice_list(args.results_folder)
|
||||
|
||||
valle_models = get_valle_models()
|
||||
|
||||
autoregressive_models = get_autoregressive_models()
|
||||
diffusion_models = get_diffusion_models()
|
||||
tokenizer_jsons = get_tokenizer_jsons()
|
||||
|
@ -337,11 +339,11 @@ def setup_gradio():
|
|||
with gr.Column():
|
||||
GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
||||
|
||||
GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True )
|
||||
GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True, visible=args.tts_backend=="tortoise" )
|
||||
GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False)
|
||||
GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||
GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
|
||||
GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0)
|
||||
GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0, visible=args.tts_backend=="tortoise")
|
||||
with gr.Row():
|
||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
||||
|
@ -357,17 +359,17 @@ def setup_gradio():
|
|||
outputs=GENERATE_SETTINGS["mic_audio"],
|
||||
)
|
||||
with gr.Column():
|
||||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
||||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
|
||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
||||
|
||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
||||
|
||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
|
||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
|
||||
|
||||
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
|
||||
|
||||
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
|
||||
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings", visible=args.tts_backend=="tortoise")
|
||||
reset_generate_settings_button = gr.Button(value="Reset to Default")
|
||||
with gr.Column(visible=False) as col:
|
||||
experimental_column = col
|
||||
|
@ -606,10 +608,12 @@ def setup_gradio():
|
|||
EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override)
|
||||
|
||||
EXEC_SETTINGS['results_folder'] = gr.Textbox(label="Results Folder", value=args.results_folder)
|
||||
|
||||
with gr.Column():
|
||||
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
|
||||
|
||||
with gr.Column(visible=args.tts_backend=="vall-e"):
|
||||
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
|
||||
|
||||
with gr.Column(visible=args.tts_backend=="tortoise"):
|
||||
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
|
||||
EXEC_SETTINGS['diffusion_model'] = gr.Dropdown(choices=diffusion_models, label="Diffusion Model", value=args.diffusion_model if args.diffusion_model else diffusion_models[0])
|
||||
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
|
||||
|
|
Loading…
Reference in New Issue
Block a user