Very, very, VERY, barebones integration with Bark (documentation soon)

This commit is contained in:
mrq 2023-04-26 04:48:09 +00:00
parent faa8da12d7
commit b6440091fb
2 changed files with 450 additions and 60 deletions

View File

@ -31,6 +31,7 @@ import music_tag
import gradio as gr
import gradio.utils
import pandas as pd
import numpy as np
from glob import glob
from datetime import datetime
@ -65,6 +66,7 @@ MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669
VALLE_ENABLED = False
BARK_ENABLED = False
try:
from vall_e.emb.qnt import encode as valle_quantize
@ -76,11 +78,98 @@ try:
VALLE_ENABLED = True
except Exception as e:
if False: # args.tts_backend == "vall-e":
raise e
pass
if VALLE_ENABLED:
TTSES.append('vall-e')
try:
from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
from bark.api import generate_audio as bark_generate_audio
from encodec.utils import convert_audio
from scipy.io.wavfile import write as write_wav
BARK_ENABLED = True
except Exception as e:
if False: # args.tts_backend == "bark":
raise e
pass
if BARK_ENABLED:
TTSES.append('bark')
class Bark_TTS():
def __init__(self, small=False):
self.input_sample_rate = BARK_SAMPLE_RATE
self.output_sample_rate = args.output_sample_rate
preload_models(
text_use_gpu=True,
coarse_use_gpu=True,
fine_use_gpu=True,
codec_use_gpu=True,
text_use_small=small,
coarse_use_small=small,
fine_use_small=small,
force_reload=False
)
def create_voice( self, voice, device='cuda' ):
transcription_json = f'./training/{voice}/whisper.json'
if not os.path.exists(transcription_json):
raise f"Transcription for voice not found: {voice}"
transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
candidates = []
for file in transcriptions:
result = transcriptions[file]
for segment in result['segments']:
entry = (
file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
segment['end'] - segment['start'],
segment['text']
)
candidates.append(entry)
candidates.sort(key=lambda x: x[1])
candidate = random.choice(candidates)
audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
text = candidate[-1]
print("Using as reference:", audio_filepath, text)
# Load and pre-process the audio waveform
model = load_codec_model(use_gpu=True)
wav, sr = torchaudio.load(audio_filepath)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.unsqueeze(0).to(device)
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = model.encode(wav)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]
# get seconds of audio
seconds = wav.shape[-1] / model.sample_rate
# generate semantic tokens
semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
self.create_voice( voice )
voice = voice.replace("/", "_")
if voice not in ALLOWED_PROMPTS:
ALLOWED_PROMPTS.add( voice )
return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
args = None
tts = None
tts_loading = False
@ -96,6 +185,9 @@ training_state = None
current_voice = None
def cleanup_voice_name( name ):
return name.split("/")[-1]
def resample( waveform, input_rate, output_rate=44100 ):
# mono-ize
waveform = torch.mean(waveform, dim=0, keepdim=True)
@ -121,6 +213,291 @@ def generate(**kwargs):
return generate_tortoise(**kwargs)
if args.tts_backend == "vall-e":
return generate_valle(**kwargs)
if args.tts_backend == "bark":
return generate_bark(**kwargs)
def generate_bark(**kwargs):
parameters = {}
parameters.update(kwargs)
voice = parameters['voice']
progress = parameters['progress'] if 'progress' in parameters else None
if parameters['seed'] == 0:
parameters['seed'] = None
usedSeed = parameters['seed']
global args
global tts
unload_whisper()
unload_voicefixer()
if not tts:
# should check if it's loading or unloaded, and load it if it's unloaded
if tts_loading:
raise Exception("TTS is still initializing...")
if progress is not None:
progress(0, "Initializing TTS...")
load_tts()
if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...")
do_gc()
voice_samples = None
conditioning_latents = None
sample_voice = None
voice_cache = {}
def get_settings( override=None ):
settings = {
'voice': parameters['voice'],
'text_temp': float(parameters['temperature']),
'waveform_temp': float(parameters['temperature']),
}
# could be better to just do a ternary on everything above, but i am not a professional
selected_voice = voice
if override is not None:
if 'voice' in override:
selected_voice = override['voice']
for k in override:
if k not in settings:
continue
settings[k] = override[k]
return settings
if not parameters['delimiter']:
parameters['delimiter'] = "\n"
elif parameters['delimiter'] == "\\n":
parameters['delimiter'] = "\n"
if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
texts = parameters['text'].split(parameters['delimiter'])
else:
texts = split_and_recombine_text(parameters['text'])
full_start_time = time.time()
outdir = f"{args.results_folder}/{voice}/"
os.makedirs(outdir, exist_ok=True)
audio_cache = {}
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
idx = 0
idx_cache = {}
for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file)
extension = os.path.splitext(filename)[1]
if extension != ".json" and extension != ".wav":
continue
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
if match and len(match) > 0:
key = int(match[0])
idx_cache[key] = True
if len(idx_cache) > 0:
keys = sorted(list(idx_cache.keys()))
idx = keys[-1] + 1
idx = pad(idx, 4)
def get_name(line=0, candidate=0, combined=False):
name = f"{idx}"
if combined:
name = f"{name}_combined"
elif len(texts) > 1:
name = f"{name}_{line}"
if parameters['candidates'] > 1:
name = f"{name}_{candidate}"
return name
def get_info( voice, settings = None, latents = True ):
info = {}
info.update(parameters)
info['time'] = time.time()-full_start_time
info['datetime'] = datetime.now().isoformat()
info['progress'] = None
del info['progress']
if info['delimiter'] == "\n":
info['delimiter'] = "\\n"
if settings is not None:
for k in settings:
if k in info:
info[k] = settings[k]
return info
INFERENCING = True
for line, cut_text in enumerate(texts):
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
print(f"{progress.msg_prefix} Generating line: {cut_text}")
start_time = time.time()
# do setting editing
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
override = None
if match and len(match) > 0:
match = match[0]
try:
override = json.loads(match[0])
cut_text = match[1].strip()
except Exception as e:
raise Exception("Prompt settings editing requested, but received invalid JSON")
settings = get_settings( override=override )
gen = tts.inference(cut_text, **settings )
run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds")
if not isinstance(gen, list):
gen = [gen]
for j, g in enumerate(gen):
wav, sr = g
name = get_name(line=line, candidate=j)
settings['text'] = cut_text
settings['time'] = run_time
settings['datetime'] = datetime.now().isoformat()
# save here in case some error happens mid-batch
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
audio_cache[name] = {
'audio': wav,
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
}
del gen
do_gc()
INFERENCING = False
for k in audio_cache:
audio = audio_cache[k]['audio']
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
if volume_adjust is not None:
audio = volume_adjust(audio)
audio_cache[k]['audio'] = audio
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
output_voices = []
for candidate in range(parameters['candidates']):
if len(texts) > 1:
audio_clips = []
for line in range(len(texts)):
name = get_name(line=line, candidate=candidate)
audio = audio_cache[name]['audio']
audio_clips.append(audio)
name = get_name(candidate=candidate, combined=True)
audio = torch.cat(audio_clips, dim=-1)
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
audio = audio.squeeze(0).cpu()
audio_cache[name] = {
'audio': audio,
'settings': get_info(voice=voice),
'output': True
}
else:
name = get_name(candidate=candidate)
audio_cache[name]['output'] = True
if args.voice_fixer:
if not voicefixer:
progress(0, "Loading voicefix...")
load_voicefixer()
try:
fixed_cache = {}
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
del audio_cache[name]['audio']
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
continue
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
voicefixer.restore(
input=path,
output=fixed,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode,
)
fixed_cache[f'{name}_fixed'] = {
'settings': audio_cache[name]['settings'],
'output': True
}
audio_cache[name]['output'] = False
for name in fixed_cache:
audio_cache[name] = fixed_cache[name]
except Exception as e:
print(e)
print("\nFailed to run Voicefixer")
for name in audio_cache:
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
if args.prune_nonfinal_outputs:
audio_cache[name]['pruned'] = True
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
continue
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
if not args.embed_output_metadata:
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
if args.embed_output_metadata:
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
continue
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
metadata.save()
if sample_voice is not None:
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
info = get_info(voice=voice, latents=False)
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
info['seed'] = usedSeed
if 'latents' in info:
del info['latents']
os.makedirs('./config/', exist_ok=True)
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') )
stats = [
[ parameters['seed'], "{:.3f}".format(info['time']) ]
]
return (
sample_voice,
output_voices,
stats,
)
def generate_valle(**kwargs):
parameters = {}
@ -289,9 +666,9 @@ def generate_valle(**kwargs):
settings['datetime'] = datetime.now().isoformat()
# save here in case some error happens mid-batch
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
audio_cache[name] = {
'audio': wav,
@ -310,7 +687,7 @@ def generate_valle(**kwargs):
audio = volume_adjust(audio)
audio_cache[k]['audio'] = audio
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
output_voices = []
for candidate in range(parameters['candidates']):
@ -323,7 +700,7 @@ def generate_valle(**kwargs):
name = get_name(candidate=candidate, combined=True)
audio = torch.cat(audio_clips, dim=-1)
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
audio = audio.squeeze(0).cpu()
audio_cache[name] = {
@ -348,8 +725,8 @@ def generate_valle(**kwargs):
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
continue
path = f'{outdir}/{voice}_{name}.wav'
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
voicefixer.restore(
input=path,
output=fixed,
@ -373,13 +750,13 @@ def generate_valle(**kwargs):
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
if args.prune_nonfinal_outputs:
audio_cache[name]['pruned'] = True
os.remove(f'{outdir}/{voice}_{name}.wav')
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
continue
output_voices.append(f'{outdir}/{voice}_{name}.wav')
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
if not args.embed_output_metadata:
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
if args.embed_output_metadata:
@ -387,7 +764,7 @@ def generate_valle(**kwargs):
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
continue
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
metadata.save()
@ -415,8 +792,6 @@ def generate_valle(**kwargs):
stats,
)
def generate_tortoise(**kwargs):
parameters = {}
parameters.update(kwargs)
@ -698,7 +1073,7 @@ def generate_tortoise(**kwargs):
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
}
# save here in case some error happens mid-batch
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)
del gen
do_gc()
@ -712,7 +1087,7 @@ def generate_tortoise(**kwargs):
audio = volume_adjust(audio)
audio_cache[k]['audio'] = audio
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
output_voices = []
for candidate in range(parameters['candidates']):
@ -725,7 +1100,7 @@ def generate_tortoise(**kwargs):
name = get_name(candidate=candidate, combined=True)
audio = torch.cat(audio_clips, dim=-1)
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
audio = audio.squeeze(0).cpu()
audio_cache[name] = {
@ -750,8 +1125,8 @@ def generate_tortoise(**kwargs):
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
continue
path = f'{outdir}/{voice}_{name}.wav'
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
voicefixer.restore(
input=path,
output=fixed,
@ -775,13 +1150,13 @@ def generate_tortoise(**kwargs):
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
if args.prune_nonfinal_outputs:
audio_cache[name]['pruned'] = True
os.remove(f'{outdir}/{voice}_{name}.wav')
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
continue
output_voices.append(f'{outdir}/{voice}_{name}.wav')
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
if not args.embed_output_metadata:
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
if args.embed_output_metadata:
@ -789,7 +1164,7 @@ def generate_tortoise(**kwargs):
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
continue
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
metadata.save()
@ -1096,9 +1471,9 @@ class TrainingState():
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
# 'ar.loss.nll', 'nar.loss.nll',
# 'ar-half.loss.nll', 'nar-half.loss.nll',
# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
'ar.loss.nll', 'nar.loss.nll',
'ar-half.loss.nll', 'nar-half.loss.nll',
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
]
keys['accuracies'] = [
@ -1464,14 +1839,14 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
return_code = training_state.process.wait()
training_state = None
def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
global training_state
losses = None
lrs = None
grad_norms = None
x_lim = [ 0, x_lim ]
y_lim = [ 0, y_lim ]
x_lim = [ x_min, x_max ]
y_lim = [ y_min, y_max ]
if not training_state:
if config_path:
@ -1490,23 +1865,23 @@ def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
losses = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['loss']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value",
x="it", y="value", # x="epoch",
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['lr']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value",
x_lim=x_lim,
x="it", y="value", # x="epoch",
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if len(training_state.statistics['grad_norm']) > 0:
grad_norms = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['grad_norm']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value",
x_lim=x_lim,
x="it", y="value", # x="epoch",
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
@ -1649,13 +2024,13 @@ def whisper_transcribe( file, language=None ):
device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad:
# omits a considerable amount of the end
"""
if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
else:
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
"""
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
"""
else:
result = whisper_model.transcribe(file)
@ -1717,7 +2092,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
os.makedirs(f'{indir}/audio/', exist_ok=True)
TARGET_SAMPLE_RATE = 22050
if args.tts_backend == "vall-e":
if args.tts_backend != "tortoise":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
@ -1735,7 +2110,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
try:
result = whisper_transcribe(file, language=language)
except Exception as e:
print("Failed to transcribe:", file)
print("Failed to transcribe:", file, e)
continue
results[basename] = result
@ -1802,7 +2177,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
results = json.load(open(infile, 'r', encoding="utf-8"))
TARGET_SAMPLE_RATE = 22050
if args.tts_backend == "vall-e":
if args.tts_backend != "tortoise":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
@ -1934,8 +2309,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
lines = { 'training': [], 'validation': [] }
segments = {}
# I'm not sure how the VALL-E implementation decides what's validation and what's not
if args.tts_backend == "vall-e":
if args.tts_backend != "tortoise":
text_length = 0
audio_length = 0
@ -3008,6 +3382,10 @@ def load_tts( restart=False,
print(f"Loading VALL-E... (Config: {valle_model})")
tts = VALLE_TTS(config=args.valle_model)
elif args.tts_backend == "bark":
print(f"Loading Bark...")
tts = Bark_TTS(small=args.low_vram)
print("Loaded TTS, ready for generation.")
tts_loading = False

View File

@ -167,6 +167,10 @@ def reset_generate_settings_proxy():
return tuple(res)
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
if args.tts_backend == "bark":
global tts
tts.create_voice( voice )
return voice
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
return voice
@ -222,13 +226,13 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len
print("Processing:", voice)
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
messages.append(message)
"""
if slice_audio:
for voice in voices:
print("Processing:", voice)
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
messages.append(message)
"""
for voice in voices:
print("Processing:", voice)
@ -400,12 +404,13 @@ def setup_gradio():
outputs=GENERATE_SETTINGS["mic_audio"],
)
with gr.Column():
preset = None
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed", visible=args.tts_backend!="tortoise")
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast", visible=args.tts_backend=="tortoise" )
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples", visible=args.tts_backend!="bark")
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
@ -490,7 +495,7 @@ def setup_gradio():
merger_button = gr.Button(value="Run Merger")
with gr.Column():
merger_output = gr.TextArea(label="Console Output", max_lines=8)
with gr.Tab("Training"):
with gr.Tab("Training", visible=args.tts_backend != "bark"):
with gr.Tab("Prepare Dataset"):
with gr.Row():
with gr.Column():
@ -586,8 +591,10 @@ def setup_gradio():
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
with gr.Row():
training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
training_graph_x_min = gr.Number(label="X Min", precision=0, value=0)
training_graph_x_max = gr.Number(label="X Max", precision=0, value=0)
training_graph_y_min = gr.Number(label="Y Min", precision=0, value=0)
training_graph_y_max = gr.Number(label="Y Max", precision=0, value=0)
with gr.Row():
start_training_button = gr.Button(value="Train")
@ -597,7 +604,7 @@ def setup_gradio():
with gr.Column():
training_loss_graph = gr.LinePlot(label="Training Metrics",
x="epoch",
x="it", # x="epoch",
y="value",
title="Loss Metrics",
color="type",
@ -606,7 +613,7 @@ def setup_gradio():
height=350,
)
training_lr_graph = gr.LinePlot(label="Training Metrics",
x="epoch",
x="it", # x="epoch",
y="value",
title="Learning Rate",
color="type",
@ -615,7 +622,7 @@ def setup_gradio():
height=350,
)
training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
x="epoch",
x="it", # x="epoch",
y="value",
title="Gradient Normals",
color="type",
@ -765,13 +772,14 @@ def setup_gradio():
inputs=show_experimental_settings,
outputs=experimental_column
)
preset.change(fn=update_presets,
inputs=preset,
outputs=[
GENERATE_SETTINGS['num_autoregressive_samples'],
GENERATE_SETTINGS['diffusion_iterations'],
],
)
if preset:
preset.change(fn=update_presets,
inputs=preset,
outputs=[
GENERATE_SETTINGS['num_autoregressive_samples'],
GENERATE_SETTINGS['diffusion_iterations'],
],
)
recompute_voice_latents.click(compute_latents_proxy,
inputs=[
@ -860,8 +868,10 @@ def setup_gradio():
training_output.change(
fn=update_training_dataplot,
inputs=[
training_graph_x_lim,
training_graph_y_lim,
training_graph_x_min,
training_graph_x_max,
training_graph_y_min,
training_graph_y_max,
],
outputs=[
training_loss_graph,
@ -874,8 +884,10 @@ def setup_gradio():
view_losses.click(
fn=update_training_dataplot,
inputs=[
training_graph_x_lim,
training_graph_y_lim,
training_graph_x_min,
training_graph_x_max,
training_graph_y_min,
training_graph_y_max,
training_configs,
],
outputs=[