Very, very, VERY barebones integration with Bark (documentation soon)

mrq 2023-04-26 04:48:09 +00:00
parent faa8da12d7
commit b6440091fb
2 changed files with 450 additions and 60 deletions


@ -31,6 +31,7 @@ import music_tag
import gradio as gr
import gradio.utils
import pandas as pd
import numpy as np

from glob import glob
from datetime import datetime
@ -65,6 +66,7 @@ MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669

VALLE_ENABLED = False
BARK_ENABLED = False
try:
	from vall_e.emb.qnt import encode as valle_quantize
@ -76,11 +78,98 @@ try:
	VALLE_ENABLED = True
except Exception as e:
	if False: # args.tts_backend == "vall-e":
		raise e
	pass

if VALLE_ENABLED:
	TTSES.append('vall-e')
try:
	from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
	from bark.api import generate_audio as bark_generate_audio
	from encodec.utils import convert_audio
	from scipy.io.wavfile import write as write_wav

	BARK_ENABLED = True
except Exception as e:
	if False: # args.tts_backend == "bark":
		raise e
	pass

if BARK_ENABLED:
	TTSES.append('bark')
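# Illustrative note (not part of this commit): BARK_ENABLED only flips to True when the
# bark and encodec packages import cleanly (e.g. a Bark checkout living under ./modules/bark,
# which is also where the prompt .npz files below are written); otherwise the backend is
# silently skipped, mirroring how the VALL-E backend is gated above.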
class Bark_TTS():
	def __init__(self, small=False):
		self.input_sample_rate = BARK_SAMPLE_RATE
		self.output_sample_rate = args.output_sample_rate

		preload_models(
			text_use_gpu=True,
			coarse_use_gpu=True,
			fine_use_gpu=True,
			codec_use_gpu=True,
			text_use_small=small,
			coarse_use_small=small,
			fine_use_small=small,
			force_reload=False
		)

	def create_voice( self, voice, device='cuda' ):
		transcription_json = f'./training/{voice}/whisper.json'
		if not os.path.exists(transcription_json):
			raise Exception(f"Transcription for voice not found: {voice}")

		transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
		candidates = []
		for file in transcriptions:
			result = transcriptions[file]
			for segment in result['segments']:
				entry = (
					file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
					segment['end'] - segment['start'],
					segment['text']
				)
				candidates.append(entry)
		candidates.sort(key=lambda x: x[1])
		candidate = random.choice(candidates)
		audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
		text = candidate[-1]

		print("Using as reference:", audio_filepath, text)

		# Load and pre-process the audio waveform
		model = load_codec_model(use_gpu=True)
		wav, sr = torchaudio.load(audio_filepath)
		wav = convert_audio(wav, sr, model.sample_rate, model.channels)
		wav = wav.unsqueeze(0).to(device)

		# Extract discrete codes from EnCodec
		with torch.no_grad():
			encoded_frames = model.encode(wav)
		codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]

		# get seconds of audio
		seconds = wav.shape[-1] / model.sample_rate

		# generate semantic tokens
		semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)

		output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
		np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)

	def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
		# note: the prompt file is saved with '/' replaced by '_', so nested voice names
		# will not pass this existence check and re-create the prompt on every call
		if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
			self.create_voice( voice )

		voice = voice.replace("/", "_")
		if voice not in ALLOWED_PROMPTS:
			ALLOWED_PROMPTS.add( voice )

		return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
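# Illustrative sketch (not part of this commit): how the Bark backend is meant to be driven
# once enabled. The voice name and text are hypothetical; create_voice() derives the .npz
# history prompt from a random whisper.json segment, and inference() returns a tuple of
# (waveform, BARK_SAMPLE_RATE) that generate_bark() below writes out with scipy's write_wav.
#
#   bark_tts = Bark_TTS(small=True)      # small=True loads the reduced models for low VRAM
#   bark_tts.create_voice("my_voice")    # writes ./modules/bark/bark/assets/prompts/my_voice.npz
#   wav, sr = bark_tts.inference("Hello there.", "my_voice", text_temp=0.7, waveform_temp=0.7)
#   write_wav("./results/my_voice/sample.wav", sr, wav)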
args = None
tts = None
tts_loading = False
@ -96,6 +185,9 @@ training_state = None
current_voice = None
def cleanup_voice_name( name ):
	return name.split("/")[-1]
def resample( waveform, input_rate, output_rate=44100 ):
	# mono-ize
	waveform = torch.mean(waveform, dim=0, keepdim=True)
@ -121,6 +213,291 @@ def generate(**kwargs):
		return generate_tortoise(**kwargs)
	if args.tts_backend == "vall-e":
		return generate_valle(**kwargs)
	if args.tts_backend == "bark":
		return generate_bark(**kwargs)
def generate_bark(**kwargs):
	parameters = {}
	parameters.update(kwargs)

	voice = parameters['voice']
	progress = parameters['progress'] if 'progress' in parameters else None
	if parameters['seed'] == 0:
		parameters['seed'] = None

	usedSeed = parameters['seed']

	global args
	global tts

	unload_whisper()
	unload_voicefixer()

	if not tts:
		# should check if it's loading or unloaded, and load it if it's unloaded
		if tts_loading:
			raise Exception("TTS is still initializing...")
		if progress is not None:
			progress(0, "Initializing TTS...")
		load_tts()
	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	do_gc()

	voice_samples = None
	conditioning_latents = None
	sample_voice = None

	voice_cache = {}

	def get_settings( override=None ):
		settings = {
			'voice': parameters['voice'],
			'text_temp': float(parameters['temperature']),
			'waveform_temp': float(parameters['temperature']),
		}

		# could be better to just do a ternary on everything above, but i am not a professional
		selected_voice = voice
		if override is not None:
			if 'voice' in override:
				selected_voice = override['voice']

			for k in override:
				if k not in settings:
					continue
				settings[k] = override[k]

		return settings

	if not parameters['delimiter']:
		parameters['delimiter'] = "\n"
	elif parameters['delimiter'] == "\\n":
		parameters['delimiter'] = "\n"

	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
		texts = parameters['text'].split(parameters['delimiter'])
	else:
		texts = split_and_recombine_text(parameters['text'])

	full_start_time = time.time()

	outdir = f"{args.results_folder}/{voice}/"
	os.makedirs(outdir, exist_ok=True)

	audio_cache = {}

	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

	idx = 0
	idx_cache = {}
	for i, file in enumerate(os.listdir(outdir)):
		filename = os.path.basename(file)
		extension = os.path.splitext(filename)[1]
		if extension != ".json" and extension != ".wav":
			continue
		match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
		if match and len(match) > 0:
			key = int(match[0])
			idx_cache[key] = True

	if len(idx_cache) > 0:
		keys = sorted(list(idx_cache.keys()))
		idx = keys[-1] + 1

	idx = pad(idx, 4)

	def get_name(line=0, candidate=0, combined=False):
		name = f"{idx}"
		if combined:
			name = f"{name}_combined"
		elif len(texts) > 1:
			name = f"{name}_{line}"
		if parameters['candidates'] > 1:
			name = f"{name}_{candidate}"
		return name

	def get_info( voice, settings = None, latents = True ):
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		info['progress'] = None
		del info['progress']

		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]

		return info

	INFERENCING = True
	for line, cut_text in enumerate(texts):
		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
		print(f"{progress.msg_prefix} Generating line: {cut_text}")
		start_time = time.time()

		# do setting editing
		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
		override = None
		if match and len(match) > 0:
			match = match[0]
			try:
				override = json.loads(match[0])
				cut_text = match[1].strip()
			except Exception as e:
				raise Exception("Prompt settings editing requested, but received invalid JSON")

		settings = get_settings( override=override )
		gen = tts.inference(cut_text, **settings )

		run_time = time.time()-start_time
		print(f"Generating line took {run_time} seconds")

		if not isinstance(gen, list):
			gen = [gen]

		for j, g in enumerate(gen):
			wav, sr = g
			name = get_name(line=line, candidate=j)

			settings['text'] = cut_text
			settings['time'] = run_time
			settings['datetime'] = datetime.now().isoformat()

			# save here in case some error happens mid-batch
			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
			write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

			audio_cache[name] = {
				'audio': wav,
				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
			}

		del gen
		do_gc()
	INFERENCING = False

	for k in audio_cache:
		audio = audio_cache[k]['audio']

		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
		if volume_adjust is not None:
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
		if len(texts) > 1:
			audio_clips = []
			for line in range(len(texts)):
				name = get_name(line=line, candidate=candidate)
				audio = audio_cache[name]['audio']
				audio_clips.append(audio)

			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
				'audio': audio,
				'settings': get_info(voice=voice),
				'output': True
			}
		else:
			name = get_name(candidate=candidate)
			audio_cache[name]['output'] = True

	if args.voice_fixer:
		if not voicefixer:
			progress(0, "Loading voicefix...")
			load_voicefixer()

		try:
			fixed_cache = {}
			for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
				del audio_cache[name]['audio']
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
					#mode=mode,
				)

				fixed_cache[f'{name}_fixed'] = {
					'settings': audio_cache[name]['settings'],
					'output': True
				}
				audio_cache[name]['output'] = False

			for name in fixed_cache:
				audio_cache[name] = fixed_cache[name]
		except Exception as e:
			print(e)
			print("\nFailed to run Voicefixer")

	for name in audio_cache:
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

		if not args.embed_output_metadata:
			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()

	if sample_voice is not None:
		sample_voice = (tts.input_sample_rate, sample_voice.numpy())

	info = get_info(voice=voice, latents=False)
	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

	info['seed'] = usedSeed
	if 'latents' in info:
		del info['latents']

	os.makedirs('./config/', exist_ok=True)
	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
		f.write(json.dumps(info, indent='\t') )

	stats = [
		[ parameters['seed'], "{:.3f}".format(info['time']) ]
	]

	return (
		sample_voice,
		output_voices,
		stats,
	)
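# Illustrative sketch (not part of this commit): the keyword arguments generate_bark()
# actually reads from the shared generate() entry point. Values are made up; 'temperature'
# is mapped onto both text_temp and waveform_temp by get_settings(), and a gr.Progress
# object is assumed since the main loop writes progress.msg_prefix.
#
#   generate_bark(
#       text="Hello world.",
#       delimiter="\\n",
#       voice="my_voice",
#       temperature=0.7,
#       candidates=1,
#       seed=0,
#       progress=gr.Progress(track_tqdm=True),
#   )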
def generate_valle(**kwargs):
	parameters = {}
@ -289,9 +666,9 @@ def generate_valle(**kwargs):
			settings['datetime'] = datetime.now().isoformat()
			# save here in case some error happens mid-batch
-			#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
-			soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
-			wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
+			#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
+			soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
+			wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

			audio_cache[name] = {
				'audio': wav,
@ -310,7 +687,7 @@ def generate_valle(**kwargs):
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
-		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
@ -323,7 +700,7 @@ def generate_valle(**kwargs):
			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
@ -348,8 +725,8 @@ def generate_valle(**kwargs):
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

-				path = f'{outdir}/{voice}_{name}.wav'
-				fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
@ -373,13 +750,13 @@ def generate_valle(**kwargs):
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
-				os.remove(f'{outdir}/{voice}_{name}.wav')
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
@ -387,7 +764,7 @@ def generate_valle(**kwargs):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

-			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()
@ -415,8 +792,6 @@ def generate_valle(**kwargs):
		stats,
	)

def generate_tortoise(**kwargs):
	parameters = {}
	parameters.update(kwargs)
@ -698,7 +1073,7 @@ def generate_tortoise(**kwargs):
				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
			}
			# save here in case some error happens mid-batch
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)

		del gen
		do_gc()
@ -712,7 +1087,7 @@ def generate_tortoise(**kwargs):
			audio = volume_adjust(audio)

		audio_cache[k]['audio'] = audio
-		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+		torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)

	output_voices = []
	for candidate in range(parameters['candidates']):
@ -725,7 +1100,7 @@ def generate_tortoise(**kwargs):
			name = get_name(candidate=candidate, combined=True)
			audio = torch.cat(audio_clips, dim=-1)
-			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+			torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)

			audio = audio.squeeze(0).cpu()
			audio_cache[name] = {
@ -750,8 +1125,8 @@ def generate_tortoise(**kwargs):
				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
					continue

-				path = f'{outdir}/{voice}_{name}.wav'
-				fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+				path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
+				fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
				voicefixer.restore(
					input=path,
					output=fixed,
@ -775,13 +1150,13 @@ def generate_tortoise(**kwargs):
		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
			if args.prune_nonfinal_outputs:
				audio_cache[name]['pruned'] = True
-				os.remove(f'{outdir}/{voice}_{name}.wav')
+				os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
			continue

-		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+		output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')

		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+			with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )

	if args.embed_output_metadata:
@ -789,7 +1164,7 @@ def generate_tortoise(**kwargs):
			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
				continue

-			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
			metadata.save()
@ -1096,9 +1471,9 @@ class TrainingState():
			'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
			'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
-			# 'ar.loss.nll', 'nar.loss.nll',
-			# 'ar-half.loss.nll', 'nar-half.loss.nll',
-			# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
+			'ar.loss.nll', 'nar.loss.nll',
+			'ar-half.loss.nll', 'nar-half.loss.nll',
+			'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
		]

		keys['accuracies'] = [
@ -1464,14 +1839,14 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
	return_code = training_state.process.wait()
	training_state = None
-def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
+def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
	global training_state
	losses = None
	lrs = None
	grad_norms = None

-	x_lim = [ 0, x_lim ]
-	y_lim = [ 0, y_lim ]
+	x_lim = [ x_min, x_max ]
+	y_lim = [ y_min, y_max ]

	if not training_state:
		if config_path:
@ -1490,23 +1865,23 @@ def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
		losses = gr.LinePlot.update(
			value = pd.DataFrame(training_state.statistics['loss']),
			x_lim=x_lim, y_lim=y_lim,
-			x="epoch", y="value",
+			x="it", y="value", # x="epoch",
			title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
			width=500, height=350
		)
	if len(training_state.statistics['lr']) > 0:
		lrs = gr.LinePlot.update(
			value = pd.DataFrame(training_state.statistics['lr']),
-			x_lim=x_lim, y_lim=y_lim,
-			x="epoch", y="value",
+			x_lim=x_lim,
+			x="it", y="value", # x="epoch",
			title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
			width=500, height=350
		)
	if len(training_state.statistics['grad_norm']) > 0:
		grad_norms = gr.LinePlot.update(
			value = pd.DataFrame(training_state.statistics['grad_norm']),
-			x_lim=x_lim, y_lim=y_lim,
-			x="epoch", y="value",
+			x_lim=x_lim,
+			x="it", y="value", # x="epoch",
			title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
			width=500, height=350
		)
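# Illustrative sketch (not part of this commit): the plot updater now takes separate per-axis
# bounds instead of a single upper limit per axis. A hypothetical call, with made-up values,
# matching the three LinePlot outputs wired up in the UI below:
#   losses, lrs, grad_norms = update_training_dataplot(
#       x_min=0, x_max=5000,   # iteration window, since the plots now use x="it"
#       y_min=0, y_max=2.5,    # value window applied to the loss plot
#       config_path=None,      # or a training config path when no live training_state exists
#   )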
@ -1649,13 +2024,13 @@ def whisper_transcribe( file, language=None ):
device = "cuda" if get_device_name() == "cuda" else "cpu" device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad: if whisper_vad:
# omits a considerable amount of the end # omits a considerable amount of the end
"""
if args.whisper_batchsize > 1: if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
else: else:
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
""" """
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
"""
else: else:
result = whisper_model.transcribe(file) result = whisper_model.transcribe(file)
@ -1717,7 +2092,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
	os.makedirs(f'{indir}/audio/', exist_ok=True)

	TARGET_SAMPLE_RATE = 22050
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
		TARGET_SAMPLE_RATE = 24000
	if tts:
		TARGET_SAMPLE_RATE = tts.input_sample_rate
@ -1735,7 +2110,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
		try:
			result = whisper_transcribe(file, language=language)
		except Exception as e:
-			print("Failed to transcribe:", file)
+			print("Failed to transcribe:", file, e)
			continue

		results[basename] = result
@ -1802,7 +2177,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
	results = json.load(open(infile, 'r', encoding="utf-8"))

	TARGET_SAMPLE_RATE = 22050
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
		TARGET_SAMPLE_RATE = 24000
	if tts:
		TARGET_SAMPLE_RATE = tts.input_sample_rate
@ -1934,8 +2309,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
	lines = { 'training': [], 'validation': [] }
	segments = {}

-	# I'm not sure how the VALL-E implementation decides what's validation and what's not
-	if args.tts_backend == "vall-e":
+	if args.tts_backend != "tortoise":
		text_length = 0
		audio_length = 0
@ -3008,6 +3382,10 @@ def load_tts( restart=False,
print(f"Loading VALL-E... (Config: {valle_model})") print(f"Loading VALL-E... (Config: {valle_model})")
tts = VALLE_TTS(config=args.valle_model) tts = VALLE_TTS(config=args.valle_model)
elif args.tts_backend == "bark":
print(f"Loading Bark...")
tts = Bark_TTS(small=args.low_vram)
print("Loaded TTS, ready for generation.") print("Loaded TTS, ready for generation.")
tts_loading = False tts_loading = False


@ -167,6 +167,10 @@ def reset_generate_settings_proxy():
	return tuple(res)

def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
if args.tts_backend == "bark":
global tts
tts.create_voice( voice )
return voice
	compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
	return voice
@ -222,13 +226,13 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len
print("Processing:", voice) print("Processing:", voice)
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress ) message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
messages.append(message) messages.append(message)
"""
if slice_audio: if slice_audio:
for voice in voices: for voice in voices:
print("Processing:", voice) print("Processing:", voice)
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress ) message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
messages.append(message) messages.append(message)
"""
	for voice in voices:
		print("Processing:", voice)
@ -400,12 +404,13 @@ def setup_gradio():
outputs=GENERATE_SETTINGS["mic_audio"], outputs=GENERATE_SETTINGS["mic_audio"],
) )
with gr.Column(): with gr.Column():
preset = None
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise") GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed") GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed", visible=args.tts_backend!="tortoise")
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" ) preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast", visible=args.tts_backend=="tortoise" )
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples") GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples", visible=args.tts_backend!="bark")
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise") GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature") GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
@ -490,7 +495,7 @@ def setup_gradio():
					merger_button = gr.Button(value="Run Merger")
				with gr.Column():
					merger_output = gr.TextArea(label="Console Output", max_lines=8)
-		with gr.Tab("Training"):
+		with gr.Tab("Training", visible=args.tts_backend != "bark"):
			with gr.Tab("Prepare Dataset"):
				with gr.Row():
					with gr.Column():
@ -586,8 +591,10 @@ def setup_gradio():
					keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
				with gr.Row():
-					training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
-					training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
+					training_graph_x_min = gr.Number(label="X Min", precision=0, value=0)
+					training_graph_x_max = gr.Number(label="X Max", precision=0, value=0)
+					training_graph_y_min = gr.Number(label="Y Min", precision=0, value=0)
+					training_graph_y_max = gr.Number(label="Y Max", precision=0, value=0)
				with gr.Row():
					start_training_button = gr.Button(value="Train")
@ -597,7 +604,7 @@ def setup_gradio():
				with gr.Column():
					training_loss_graph = gr.LinePlot(label="Training Metrics",
-						x="epoch",
+						x="it", # x="epoch",
						y="value",
						title="Loss Metrics",
						color="type",
@ -606,7 +613,7 @@ def setup_gradio():
						height=350,
					)
					training_lr_graph = gr.LinePlot(label="Training Metrics",
-						x="epoch",
+						x="it", # x="epoch",
						y="value",
						title="Learning Rate",
						color="type",
@ -615,7 +622,7 @@ def setup_gradio():
						height=350,
					)
					training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
-						x="epoch",
+						x="it", # x="epoch",
						y="value",
						title="Gradient Normals",
						color="type",
@ -765,13 +772,14 @@ def setup_gradio():
			inputs=show_experimental_settings,
			outputs=experimental_column
		)

-		preset.change(fn=update_presets,
-			inputs=preset,
-			outputs=[
-				GENERATE_SETTINGS['num_autoregressive_samples'],
-				GENERATE_SETTINGS['diffusion_iterations'],
-			],
-		)
+		if preset:
+			preset.change(fn=update_presets,
+				inputs=preset,
+				outputs=[
+					GENERATE_SETTINGS['num_autoregressive_samples'],
+					GENERATE_SETTINGS['diffusion_iterations'],
+				],
+			)
		recompute_voice_latents.click(compute_latents_proxy,
			inputs=[
@ -860,8 +868,10 @@ def setup_gradio():
		training_output.change(
			fn=update_training_dataplot,
			inputs=[
-				training_graph_x_lim,
-				training_graph_y_lim,
+				training_graph_x_min,
+				training_graph_x_max,
+				training_graph_y_min,
+				training_graph_y_max,
			],
			outputs=[
				training_loss_graph,
@ -874,8 +884,10 @@ def setup_gradio():
		view_losses.click(
			fn=update_training_dataplot,
			inputs=[
-				training_graph_x_lim,
-				training_graph_y_lim,
+				training_graph_x_min,
+				training_graph_x_max,
+				training_graph_y_min,
+				training_graph_y_max,
				training_configs,
			],
			outputs=[