Very, very, VERY, barebones integration with Bark (documentation soon)
This commit is contained in:
parent
faa8da12d7
commit
b6440091fb
456
src/utils.py
456
src/utils.py
|
@ -31,6 +31,7 @@ import music_tag
|
|||
import gradio as gr
|
||||
import gradio.utils
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from glob import glob
|
||||
from datetime import datetime
|
||||
|
@ -65,6 +66,7 @@ MIN_TRAINING_DURATION = 0.6
|
|||
MAX_TRAINING_DURATION = 11.6097505669
|
||||
|
||||
VALLE_ENABLED = False
|
||||
BARK_ENABLED = False
|
||||
|
||||
try:
|
||||
from vall_e.emb.qnt import encode as valle_quantize
|
||||
|
@ -76,11 +78,98 @@ try:
|
|||
|
||||
VALLE_ENABLED = True
|
||||
except Exception as e:
|
||||
if False: # args.tts_backend == "vall-e":
|
||||
raise e
|
||||
pass
|
||||
|
||||
if VALLE_ENABLED:
|
||||
TTSES.append('vall-e')
|
||||
|
||||
try:
|
||||
from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
|
||||
from bark.api import generate_audio as bark_generate_audio
|
||||
from encodec.utils import convert_audio
|
||||
|
||||
from scipy.io.wavfile import write as write_wav
|
||||
|
||||
BARK_ENABLED = True
|
||||
except Exception as e:
|
||||
if False: # args.tts_backend == "bark":
|
||||
raise e
|
||||
pass
|
||||
|
||||
if BARK_ENABLED:
|
||||
TTSES.append('bark')
|
||||
class Bark_TTS():
|
||||
def __init__(self, small=False):
|
||||
self.input_sample_rate = BARK_SAMPLE_RATE
|
||||
self.output_sample_rate = args.output_sample_rate
|
||||
|
||||
preload_models(
|
||||
text_use_gpu=True,
|
||||
coarse_use_gpu=True,
|
||||
fine_use_gpu=True,
|
||||
codec_use_gpu=True,
|
||||
|
||||
text_use_small=small,
|
||||
coarse_use_small=small,
|
||||
fine_use_small=small,
|
||||
|
||||
force_reload=False
|
||||
)
|
||||
|
||||
def create_voice( self, voice, device='cuda' ):
|
||||
transcription_json = f'./training/{voice}/whisper.json'
|
||||
if not os.path.exists(transcription_json):
|
||||
raise f"Transcription for voice not found: {voice}"
|
||||
|
||||
transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
|
||||
candidates = []
|
||||
for file in transcriptions:
|
||||
result = transcriptions[file]
|
||||
for segment in result['segments']:
|
||||
entry = (
|
||||
file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
|
||||
segment['end'] - segment['start'],
|
||||
segment['text']
|
||||
)
|
||||
candidates.append(entry)
|
||||
|
||||
candidates.sort(key=lambda x: x[1])
|
||||
candidate = random.choice(candidates)
|
||||
audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
|
||||
text = candidate[-1]
|
||||
|
||||
print("Using as reference:", audio_filepath, text)
|
||||
|
||||
# Load and pre-process the audio waveform
|
||||
model = load_codec_model(use_gpu=True)
|
||||
wav, sr = torchaudio.load(audio_filepath)
|
||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||
wav = wav.unsqueeze(0).to(device)
|
||||
|
||||
# Extract discrete codes from EnCodec
|
||||
with torch.no_grad():
|
||||
encoded_frames = model.encode(wav)
|
||||
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]
|
||||
|
||||
# get seconds of audio
|
||||
seconds = wav.shape[-1] / model.sample_rate
|
||||
# generate semantic tokens
|
||||
semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
|
||||
|
||||
output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
|
||||
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
|
||||
|
||||
def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
|
||||
if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
|
||||
self.create_voice( voice )
|
||||
voice = voice.replace("/", "_")
|
||||
if voice not in ALLOWED_PROMPTS:
|
||||
ALLOWED_PROMPTS.add( voice )
|
||||
|
||||
return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
|
||||
|
||||
args = None
|
||||
tts = None
|
||||
tts_loading = False
|
||||
|
@ -96,6 +185,9 @@ training_state = None
|
|||
|
||||
current_voice = None
|
||||
|
||||
def cleanup_voice_name( name ):
|
||||
return name.split("/")[-1]
|
||||
|
||||
def resample( waveform, input_rate, output_rate=44100 ):
|
||||
# mono-ize
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
|
@ -121,6 +213,291 @@ def generate(**kwargs):
|
|||
return generate_tortoise(**kwargs)
|
||||
if args.tts_backend == "vall-e":
|
||||
return generate_valle(**kwargs)
|
||||
if args.tts_backend == "bark":
|
||||
return generate_bark(**kwargs)
|
||||
|
||||
def generate_bark(**kwargs):
|
||||
parameters = {}
|
||||
parameters.update(kwargs)
|
||||
|
||||
voice = parameters['voice']
|
||||
progress = parameters['progress'] if 'progress' in parameters else None
|
||||
if parameters['seed'] == 0:
|
||||
parameters['seed'] = None
|
||||
|
||||
usedSeed = parameters['seed']
|
||||
|
||||
global args
|
||||
global tts
|
||||
|
||||
unload_whisper()
|
||||
unload_voicefixer()
|
||||
|
||||
if not tts:
|
||||
# should check if it's loading or unloaded, and load it if it's unloaded
|
||||
if tts_loading:
|
||||
raise Exception("TTS is still initializing...")
|
||||
if progress is not None:
|
||||
progress(0, "Initializing TTS...")
|
||||
load_tts()
|
||||
if hasattr(tts, "loading") and tts.loading:
|
||||
raise Exception("TTS is still initializing...")
|
||||
|
||||
do_gc()
|
||||
|
||||
voice_samples = None
|
||||
conditioning_latents = None
|
||||
sample_voice = None
|
||||
|
||||
voice_cache = {}
|
||||
|
||||
def get_settings( override=None ):
|
||||
settings = {
|
||||
'voice': parameters['voice'],
|
||||
'text_temp': float(parameters['temperature']),
|
||||
'waveform_temp': float(parameters['temperature']),
|
||||
}
|
||||
|
||||
# could be better to just do a ternary on everything above, but i am not a professional
|
||||
selected_voice = voice
|
||||
if override is not None:
|
||||
if 'voice' in override:
|
||||
selected_voice = override['voice']
|
||||
|
||||
for k in override:
|
||||
if k not in settings:
|
||||
continue
|
||||
settings[k] = override[k]
|
||||
|
||||
return settings
|
||||
|
||||
if not parameters['delimiter']:
|
||||
parameters['delimiter'] = "\n"
|
||||
elif parameters['delimiter'] == "\\n":
|
||||
parameters['delimiter'] = "\n"
|
||||
|
||||
if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
|
||||
texts = parameters['text'].split(parameters['delimiter'])
|
||||
else:
|
||||
texts = split_and_recombine_text(parameters['text'])
|
||||
|
||||
full_start_time = time.time()
|
||||
|
||||
outdir = f"{args.results_folder}/{voice}/"
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
audio_cache = {}
|
||||
|
||||
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||
|
||||
idx = 0
|
||||
idx_cache = {}
|
||||
for i, file in enumerate(os.listdir(outdir)):
|
||||
filename = os.path.basename(file)
|
||||
extension = os.path.splitext(filename)[1]
|
||||
if extension != ".json" and extension != ".wav":
|
||||
continue
|
||||
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
|
||||
if match and len(match) > 0:
|
||||
key = int(match[0])
|
||||
idx_cache[key] = True
|
||||
|
||||
if len(idx_cache) > 0:
|
||||
keys = sorted(list(idx_cache.keys()))
|
||||
idx = keys[-1] + 1
|
||||
|
||||
idx = pad(idx, 4)
|
||||
|
||||
def get_name(line=0, candidate=0, combined=False):
|
||||
name = f"{idx}"
|
||||
if combined:
|
||||
name = f"{name}_combined"
|
||||
elif len(texts) > 1:
|
||||
name = f"{name}_{line}"
|
||||
if parameters['candidates'] > 1:
|
||||
name = f"{name}_{candidate}"
|
||||
return name
|
||||
|
||||
def get_info( voice, settings = None, latents = True ):
|
||||
info = {}
|
||||
info.update(parameters)
|
||||
|
||||
info['time'] = time.time()-full_start_time
|
||||
info['datetime'] = datetime.now().isoformat()
|
||||
|
||||
info['progress'] = None
|
||||
del info['progress']
|
||||
|
||||
if info['delimiter'] == "\n":
|
||||
info['delimiter'] = "\\n"
|
||||
|
||||
if settings is not None:
|
||||
for k in settings:
|
||||
if k in info:
|
||||
info[k] = settings[k]
|
||||
return info
|
||||
|
||||
INFERENCING = True
|
||||
for line, cut_text in enumerate(texts):
|
||||
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
|
||||
print(f"{progress.msg_prefix} Generating line: {cut_text}")
|
||||
start_time = time.time()
|
||||
|
||||
# do setting editing
|
||||
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
||||
override = None
|
||||
if match and len(match) > 0:
|
||||
match = match[0]
|
||||
try:
|
||||
override = json.loads(match[0])
|
||||
cut_text = match[1].strip()
|
||||
except Exception as e:
|
||||
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||
|
||||
settings = get_settings( override=override )
|
||||
|
||||
gen = tts.inference(cut_text, **settings )
|
||||
|
||||
run_time = time.time()-start_time
|
||||
print(f"Generating line took {run_time} seconds")
|
||||
|
||||
if not isinstance(gen, list):
|
||||
gen = [gen]
|
||||
|
||||
for j, g in enumerate(gen):
|
||||
wav, sr = g
|
||||
name = get_name(line=line, candidate=j)
|
||||
|
||||
settings['text'] = cut_text
|
||||
settings['time'] = run_time
|
||||
settings['datetime'] = datetime.now().isoformat()
|
||||
|
||||
# save here in case some error happens mid-batch
|
||||
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
||||
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
|
||||
audio_cache[name] = {
|
||||
'audio': wav,
|
||||
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||
}
|
||||
|
||||
del gen
|
||||
do_gc()
|
||||
INFERENCING = False
|
||||
|
||||
for k in audio_cache:
|
||||
audio = audio_cache[k]['audio']
|
||||
|
||||
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
|
||||
if volume_adjust is not None:
|
||||
audio = volume_adjust(audio)
|
||||
|
||||
audio_cache[k]['audio'] = audio
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
|
||||
|
||||
output_voices = []
|
||||
for candidate in range(parameters['candidates']):
|
||||
if len(texts) > 1:
|
||||
audio_clips = []
|
||||
for line in range(len(texts)):
|
||||
name = get_name(line=line, candidate=candidate)
|
||||
audio = audio_cache[name]['audio']
|
||||
audio_clips.append(audio)
|
||||
|
||||
name = get_name(candidate=candidate, combined=True)
|
||||
audio = torch.cat(audio_clips, dim=-1)
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
|
||||
|
||||
audio = audio.squeeze(0).cpu()
|
||||
audio_cache[name] = {
|
||||
'audio': audio,
|
||||
'settings': get_info(voice=voice),
|
||||
'output': True
|
||||
}
|
||||
else:
|
||||
name = get_name(candidate=candidate)
|
||||
audio_cache[name]['output'] = True
|
||||
|
||||
|
||||
if args.voice_fixer:
|
||||
if not voicefixer:
|
||||
progress(0, "Loading voicefix...")
|
||||
load_voicefixer()
|
||||
|
||||
try:
|
||||
fixed_cache = {}
|
||||
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||
del audio_cache[name]['audio']
|
||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
continue
|
||||
|
||||
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
|
||||
voicefixer.restore(
|
||||
input=path,
|
||||
output=fixed,
|
||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||
#mode=mode,
|
||||
)
|
||||
|
||||
fixed_cache[f'{name}_fixed'] = {
|
||||
'settings': audio_cache[name]['settings'],
|
||||
'output': True
|
||||
}
|
||||
audio_cache[name]['output'] = False
|
||||
|
||||
for name in fixed_cache:
|
||||
audio_cache[name] = fixed_cache[name]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("\nFailed to run Voicefixer")
|
||||
|
||||
for name in audio_cache:
|
||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
if args.prune_nonfinal_outputs:
|
||||
audio_cache[name]['pruned'] = True
|
||||
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
continue
|
||||
|
||||
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
|
||||
if not args.embed_output_metadata:
|
||||
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||
|
||||
if args.embed_output_metadata:
|
||||
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||
continue
|
||||
|
||||
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
|
||||
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||
metadata.save()
|
||||
|
||||
if sample_voice is not None:
|
||||
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
||||
|
||||
info = get_info(voice=voice, latents=False)
|
||||
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||
|
||||
info['seed'] = usedSeed
|
||||
if 'latents' in info:
|
||||
del info['latents']
|
||||
|
||||
os.makedirs('./config/', exist_ok=True)
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(info, indent='\t') )
|
||||
|
||||
stats = [
|
||||
[ parameters['seed'], "{:.3f}".format(info['time']) ]
|
||||
]
|
||||
|
||||
return (
|
||||
sample_voice,
|
||||
output_voices,
|
||||
stats,
|
||||
)
|
||||
|
||||
def generate_valle(**kwargs):
|
||||
parameters = {}
|
||||
|
@ -289,9 +666,9 @@ def generate_valle(**kwargs):
|
|||
settings['datetime'] = datetime.now().isoformat()
|
||||
|
||||
# save here in case some error happens mid-batch
|
||||
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
|
||||
soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
|
||||
wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
|
||||
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||
soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
|
||||
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
|
||||
audio_cache[name] = {
|
||||
'audio': wav,
|
||||
|
@ -310,7 +687,7 @@ def generate_valle(**kwargs):
|
|||
audio = volume_adjust(audio)
|
||||
|
||||
audio_cache[k]['audio'] = audio
|
||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
|
||||
|
||||
output_voices = []
|
||||
for candidate in range(parameters['candidates']):
|
||||
|
@ -323,7 +700,7 @@ def generate_valle(**kwargs):
|
|||
|
||||
name = get_name(candidate=candidate, combined=True)
|
||||
audio = torch.cat(audio_clips, dim=-1)
|
||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
|
||||
|
||||
audio = audio.squeeze(0).cpu()
|
||||
audio_cache[name] = {
|
||||
|
@ -348,8 +725,8 @@ def generate_valle(**kwargs):
|
|||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
continue
|
||||
|
||||
path = f'{outdir}/{voice}_{name}.wav'
|
||||
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
||||
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
|
||||
voicefixer.restore(
|
||||
input=path,
|
||||
output=fixed,
|
||||
|
@ -373,13 +750,13 @@ def generate_valle(**kwargs):
|
|||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
if args.prune_nonfinal_outputs:
|
||||
audio_cache[name]['pruned'] = True
|
||||
os.remove(f'{outdir}/{voice}_{name}.wav')
|
||||
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
continue
|
||||
|
||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
|
||||
if not args.embed_output_metadata:
|
||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||
|
||||
if args.embed_output_metadata:
|
||||
|
@ -387,7 +764,7 @@ def generate_valle(**kwargs):
|
|||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||
continue
|
||||
|
||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
||||
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
|
||||
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||
metadata.save()
|
||||
|
||||
|
@ -415,8 +792,6 @@ def generate_valle(**kwargs):
|
|||
stats,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def generate_tortoise(**kwargs):
|
||||
parameters = {}
|
||||
parameters.update(kwargs)
|
||||
|
@ -698,7 +1073,7 @@ def generate_tortoise(**kwargs):
|
|||
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||
}
|
||||
# save here in case some error happens mid-batch
|
||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)
|
||||
|
||||
del gen
|
||||
do_gc()
|
||||
|
@ -712,7 +1087,7 @@ def generate_tortoise(**kwargs):
|
|||
audio = volume_adjust(audio)
|
||||
|
||||
audio_cache[k]['audio'] = audio
|
||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
|
||||
|
||||
output_voices = []
|
||||
for candidate in range(parameters['candidates']):
|
||||
|
@ -725,7 +1100,7 @@ def generate_tortoise(**kwargs):
|
|||
|
||||
name = get_name(candidate=candidate, combined=True)
|
||||
audio = torch.cat(audio_clips, dim=-1)
|
||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
|
||||
|
||||
audio = audio.squeeze(0).cpu()
|
||||
audio_cache[name] = {
|
||||
|
@ -750,8 +1125,8 @@ def generate_tortoise(**kwargs):
|
|||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
continue
|
||||
|
||||
path = f'{outdir}/{voice}_{name}.wav'
|
||||
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
||||
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
|
||||
voicefixer.restore(
|
||||
input=path,
|
||||
output=fixed,
|
||||
|
@ -775,13 +1150,13 @@ def generate_tortoise(**kwargs):
|
|||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||
if args.prune_nonfinal_outputs:
|
||||
audio_cache[name]['pruned'] = True
|
||||
os.remove(f'{outdir}/{voice}_{name}.wav')
|
||||
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
continue
|
||||
|
||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||
|
||||
if not args.embed_output_metadata:
|
||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||
|
||||
if args.embed_output_metadata:
|
||||
|
@ -789,7 +1164,7 @@ def generate_tortoise(**kwargs):
|
|||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||
continue
|
||||
|
||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
||||
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
|
||||
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||
metadata.save()
|
||||
|
||||
|
@ -1096,9 +1471,9 @@ class TrainingState():
|
|||
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
|
||||
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
|
||||
|
||||
# 'ar.loss.nll', 'nar.loss.nll',
|
||||
# 'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||
# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||
'ar.loss.nll', 'nar.loss.nll',
|
||||
'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||
]
|
||||
|
||||
keys['accuracies'] = [
|
||||
|
@ -1464,14 +1839,14 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
|||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
|
||||
def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
|
||||
def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
|
||||
global training_state
|
||||
losses = None
|
||||
lrs = None
|
||||
grad_norms = None
|
||||
|
||||
x_lim = [ 0, x_lim ]
|
||||
y_lim = [ 0, y_lim ]
|
||||
x_lim = [ x_min, x_max ]
|
||||
y_lim = [ y_min, y_max ]
|
||||
|
||||
if not training_state:
|
||||
if config_path:
|
||||
|
@ -1490,23 +1865,23 @@ def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
|
|||
losses = gr.LinePlot.update(
|
||||
value = pd.DataFrame(training_state.statistics['loss']),
|
||||
x_lim=x_lim, y_lim=y_lim,
|
||||
x="epoch", y="value",
|
||||
x="it", y="value", # x="epoch",
|
||||
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500, height=350
|
||||
)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(
|
||||
value = pd.DataFrame(training_state.statistics['lr']),
|
||||
x_lim=x_lim, y_lim=y_lim,
|
||||
x="epoch", y="value",
|
||||
x_lim=x_lim,
|
||||
x="it", y="value", # x="epoch",
|
||||
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500, height=350
|
||||
)
|
||||
if len(training_state.statistics['grad_norm']) > 0:
|
||||
grad_norms = gr.LinePlot.update(
|
||||
value = pd.DataFrame(training_state.statistics['grad_norm']),
|
||||
x_lim=x_lim, y_lim=y_lim,
|
||||
x="epoch", y="value",
|
||||
x_lim=x_lim,
|
||||
x="it", y="value", # x="epoch",
|
||||
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500, height=350
|
||||
)
|
||||
|
@ -1649,13 +2024,13 @@ def whisper_transcribe( file, language=None ):
|
|||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||
if whisper_vad:
|
||||
# omits a considerable amount of the end
|
||||
"""
|
||||
if args.whisper_batchsize > 1:
|
||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
||||
else:
|
||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||
"""
|
||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||
"""
|
||||
else:
|
||||
result = whisper_model.transcribe(file)
|
||||
|
||||
|
@ -1717,7 +2092,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
|||
os.makedirs(f'{indir}/audio/', exist_ok=True)
|
||||
|
||||
TARGET_SAMPLE_RATE = 22050
|
||||
if args.tts_backend == "vall-e":
|
||||
if args.tts_backend != "tortoise":
|
||||
TARGET_SAMPLE_RATE = 24000
|
||||
if tts:
|
||||
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
||||
|
@ -1735,7 +2110,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
|||
try:
|
||||
result = whisper_transcribe(file, language=language)
|
||||
except Exception as e:
|
||||
print("Failed to transcribe:", file)
|
||||
print("Failed to transcribe:", file, e)
|
||||
continue
|
||||
|
||||
results[basename] = result
|
||||
|
@ -1802,7 +2177,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
|||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||
|
||||
TARGET_SAMPLE_RATE = 22050
|
||||
if args.tts_backend == "vall-e":
|
||||
if args.tts_backend != "tortoise":
|
||||
TARGET_SAMPLE_RATE = 24000
|
||||
if tts:
|
||||
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
||||
|
@ -1934,8 +2309,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
lines = { 'training': [], 'validation': [] }
|
||||
segments = {}
|
||||
|
||||
# I'm not sure how the VALL-E implementation decides what's validation and what's not
|
||||
if args.tts_backend == "vall-e":
|
||||
if args.tts_backend != "tortoise":
|
||||
text_length = 0
|
||||
audio_length = 0
|
||||
|
||||
|
@ -3008,6 +3382,10 @@ def load_tts( restart=False,
|
|||
|
||||
print(f"Loading VALL-E... (Config: {valle_model})")
|
||||
tts = VALLE_TTS(config=args.valle_model)
|
||||
elif args.tts_backend == "bark":
|
||||
|
||||
print(f"Loading Bark...")
|
||||
tts = Bark_TTS(small=args.low_vram)
|
||||
|
||||
print("Loaded TTS, ready for generation.")
|
||||
tts_loading = False
|
||||
|
|
40
src/webui.py
40
src/webui.py
|
@ -167,6 +167,10 @@ def reset_generate_settings_proxy():
|
|||
return tuple(res)
|
||||
|
||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
if args.tts_backend == "bark":
|
||||
global tts
|
||||
tts.create_voice( voice )
|
||||
return voice
|
||||
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
||||
return voice
|
||||
|
||||
|
@ -222,13 +226,13 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len
|
|||
print("Processing:", voice)
|
||||
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
|
||||
messages.append(message)
|
||||
"""
|
||||
|
||||
if slice_audio:
|
||||
for voice in voices:
|
||||
print("Processing:", voice)
|
||||
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
|
||||
messages.append(message)
|
||||
"""
|
||||
|
||||
for voice in voices:
|
||||
print("Processing:", voice)
|
||||
|
@ -400,12 +404,13 @@ def setup_gradio():
|
|||
outputs=GENERATE_SETTINGS["mic_audio"],
|
||||
)
|
||||
with gr.Column():
|
||||
preset = None
|
||||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
|
||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed", visible=args.tts_backend!="tortoise")
|
||||
|
||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast", visible=args.tts_backend=="tortoise" )
|
||||
|
||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples", visible=args.tts_backend!="bark")
|
||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
|
||||
|
||||
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
|
||||
|
@ -490,7 +495,7 @@ def setup_gradio():
|
|||
merger_button = gr.Button(value="Run Merger")
|
||||
with gr.Column():
|
||||
merger_output = gr.TextArea(label="Console Output", max_lines=8)
|
||||
with gr.Tab("Training"):
|
||||
with gr.Tab("Training", visible=args.tts_backend != "bark"):
|
||||
with gr.Tab("Prepare Dataset"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
|
@ -586,8 +591,10 @@ def setup_gradio():
|
|||
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
|
||||
with gr.Row():
|
||||
training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
|
||||
training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
|
||||
training_graph_x_min = gr.Number(label="X Min", precision=0, value=0)
|
||||
training_graph_x_max = gr.Number(label="X Max", precision=0, value=0)
|
||||
training_graph_y_min = gr.Number(label="Y Min", precision=0, value=0)
|
||||
training_graph_y_max = gr.Number(label="Y Max", precision=0, value=0)
|
||||
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
|
@ -597,7 +604,7 @@ def setup_gradio():
|
|||
|
||||
with gr.Column():
|
||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="epoch",
|
||||
x="it", # x="epoch",
|
||||
y="value",
|
||||
title="Loss Metrics",
|
||||
color="type",
|
||||
|
@ -606,7 +613,7 @@ def setup_gradio():
|
|||
height=350,
|
||||
)
|
||||
training_lr_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="epoch",
|
||||
x="it", # x="epoch",
|
||||
y="value",
|
||||
title="Learning Rate",
|
||||
color="type",
|
||||
|
@ -615,7 +622,7 @@ def setup_gradio():
|
|||
height=350,
|
||||
)
|
||||
training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="epoch",
|
||||
x="it", # x="epoch",
|
||||
y="value",
|
||||
title="Gradient Normals",
|
||||
color="type",
|
||||
|
@ -765,6 +772,7 @@ def setup_gradio():
|
|||
inputs=show_experimental_settings,
|
||||
outputs=experimental_column
|
||||
)
|
||||
if preset:
|
||||
preset.change(fn=update_presets,
|
||||
inputs=preset,
|
||||
outputs=[
|
||||
|
@ -860,8 +868,10 @@ def setup_gradio():
|
|||
training_output.change(
|
||||
fn=update_training_dataplot,
|
||||
inputs=[
|
||||
training_graph_x_lim,
|
||||
training_graph_y_lim,
|
||||
training_graph_x_min,
|
||||
training_graph_x_max,
|
||||
training_graph_y_min,
|
||||
training_graph_y_max,
|
||||
],
|
||||
outputs=[
|
||||
training_loss_graph,
|
||||
|
@ -874,8 +884,10 @@ def setup_gradio():
|
|||
view_losses.click(
|
||||
fn=update_training_dataplot,
|
||||
inputs=[
|
||||
training_graph_x_lim,
|
||||
training_graph_y_lim,
|
||||
training_graph_x_min,
|
||||
training_graph_x_max,
|
||||
training_graph_y_min,
|
||||
training_graph_y_max,
|
||||
training_configs,
|
||||
],
|
||||
outputs=[
|
||||
|
|
Loading…
Reference in New Issue
Block a user