Compare commits

...

5 Commits

Author      SHA1        Message                                         Date
yqxtqymn    1e2436aac9  Update 'src/utils.py' (removed some comments)   2023-03-06 02:04:19 +00:00
yqxtqymn    f657f30e2b  Update 'src/utils.py' (whisper->whisperx)       2023-03-06 01:59:58 +00:00
yqxtqymn    4f123910fb  Update 'src/webui.py' (whisper->whisperx)       2023-03-06 01:59:42 +00:00
yqxtqymn    9ca5192309  Update 'src/utils.py' (whisper->whisperx)       2023-03-06 00:47:56 +00:00
yqxtqymn    079cd32074  Update 'requirements.txt' (whisper->whisperx)   2023-03-06 00:47:03 +00:00
3 changed files with 2731 additions and 2646 deletions

requirements.txt

@@ -1,4 +1,4 @@
-git+https://github.com/openai/whisper.git
+git+https://github.com/m-bain/whisperx.git
 more-itertools
 ffmpeg-python
 gradio
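(The dependency swap above is the entire requirements.txt change. Reproducing it in an existing environment should amount to uninstalling the old package and running pip install git+https://github.com/m-bain/whisperx.git; the uninstall name openai-whisper is an assumption based on upstream's package metadata, not something shown in this diff.)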

src/utils.py

@@ -1,4 +1,5 @@
 import os
 if 'XDG_CACHE_HOME' not in os.environ:
     os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
@@ -27,6 +28,7 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
+import whisperx
 from datetime import datetime
 from datetime import timedelta
@@ -36,9 +38,9 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name
-MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
-WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
-WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
+MODELS[
+    'dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
+WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
 EPOCH_SCHEDULE = [9, 18, 25, 33]
 args = None
@@ -49,6 +51,7 @@ voicefixer = None
 whisper_model = None
 training_state = None
+
 def generate(
     text,
     delimiter,
@@ -110,13 +113,16 @@ def generate(
     if voice_samples and len(voice_samples) > 0:
         sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
-        conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
+        conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean,
+                                                            progress=progress, slices=voice_latents_chunks,
+                                                            force_cpu=args.force_cpu_for_conditioning_latents)
         if len(conditioning_latents) == 4:
             conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
         if voice != "microphone":
             if hasattr(tts, 'autoregressive_model_hash'):
-                torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+                torch.save(conditioning_latents,
+                           f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
             else:
                 torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
     voice_samples = None
@@ -132,10 +138,10 @@ def generate(
     seed = None
     if conditioning_latents is not None and len(conditioning_latents) == 2 and cvvp_weight > 0:
-        print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
+        print(
+            "Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
         cvvp_weight = 0
     settings = {
         'temperature': float(temperature),
@@ -199,7 +205,8 @@ def generate(
         beta=8.555504641634386,
     )
-    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
+    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume,
+                                              gain_type="amplitude") if args.output_volume != 1 else None
     idx = 0
     idx_cache = {}
@@ -380,7 +387,6 @@ def generate(
     with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
         f.write(json.dumps(info, indent='\t'))
-
     if voice and voice != "random" and conditioning_latents is not None:
         latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
@@ -428,10 +434,12 @@ def generate(
         stats,
     )
+
 def cancel_generate():
     import tortoise.api
     tortoise.api.STOP_SIGNAL = True
+
 def hash_file(path, algo="md5", buffer_size=0):
     import hashlib
@@ -458,6 +466,7 @@ def hash_file(path, algo="md5", buffer_size=0):
     return "{0}".format(hash.hexdigest())
+
 def update_baseline_for_latents_chunks(voice):
     path = f'{get_voice_dir()}/{voice}/'
     if not os.path.isdir(path):
@@ -481,6 +490,7 @@ def update_baseline_for_latents_chunks( voice ):
         return int(total_duration / total) if total > 0 else 1
     return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
+
 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
     global tts
     global args
@@ -498,18 +508,22 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
     if voice_samples is None:
         return
-    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
+    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean,
+                                                        progress=progress, slices=voice_latents_chunks,
+                                                        force_cpu=args.force_cpu_for_conditioning_latents)
     if len(conditioning_latents) == 4:
         conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
     if hasattr(tts, 'autoregressive_model_hash'):
-        torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+        torch.save(conditioning_latents,
+                   f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
     else:
         torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
     return voice
+
 # superfluous, but it cleans up some things
 class TrainingState():
     def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
@@ -580,7 +594,8 @@ class TrainingState():
         self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path]
         print("Spawning process: ", " ".join(self.cmd))
-        self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
+        self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+                                        universal_newlines=True)
     def load_losses(self, update=False):
         if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
@@ -599,7 +614,8 @@ class TrainingState():
         self.statistics = []
         if use_tensorboard:
-            logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
+            logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if
+                           d[:6] == "events"])
             if update:
                 logs = [logs[-1]]
@@ -893,6 +909,7 @@ class TrainingState():
             message,
         )
+
 def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
     global training_state
     if training_state and training_state.process:
@@ -911,7 +928,8 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
         if training_state.killed:
             return
-        result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
+        result, percent, message = training_state.parse(line=line, verbose=verbose,
+                                                        keep_x_past_datasets=keep_x_past_datasets, progress=progress)
         print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
         if result:
             yield result
@@ -924,6 +942,7 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
     return_code = training_state.process.wait()
     training_state = None
+
 def update_training_dataplot(config_path=None):
     global training_state
     update = None
@@ -932,22 +951,29 @@ def update_training_dataplot(config_path=None):
     if config_path:
         training_state = TrainingState(config_path=config_path, start=False)
         if training_state.statistics:
-            update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
+            update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics),
+                                        x_lim=[0, training_state.its], x="step", y="value",
+                                        title="Training Metrics", color="type", tooltip=['step', 'value', 'type'],
+                                        width=600, height=350, )
         del training_state
         training_state = None
     elif training_state.statistics:
         training_state.load_losses()
-        update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
+        update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0, training_state.its],
+                                    x="step", y="value", title="Training Metrics", color="type",
+                                    tooltip=['step', 'value', 'type'], width=600, height=350, )
     return update
+
 def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
     global training_state
     if not training_state or not training_state.process:
         return "Training not in progress"
     for line in iter(training_state.process.stdout.readline, ""):
-        result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
+        result, percent, message = training_state.parse(line=line, verbose=verbose,
+                                                        keep_x_past_datasets=keep_x_past_datasets, progress=progress)
         print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
         if result:
             yield result
@@ -955,6 +981,7 @@ def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
         if progress is not None and message:
             progress(percent, message)
+
 def stop_training():
     global training_state
     if training_state is None:
@@ -965,7 +992,8 @@ def stop_training():
     children = []
     # wrapped in a try/catch in case for some reason this fails outside of Linux
     try:
-        children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
+        children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if
+                    './src/train.py' in p.info['cmdline']]
     except Exception as e:
         pass
@@ -981,10 +1009,12 @@ def stop_training():
     print("Killed training process.")
     return f"Training cancelled: {return_code}"
+
 def get_halfp_model_path():
     autoregressive_model_path = get_model_path('autoregressive.pth')
     return autoregressive_model_path.replace(".pth", "_half.pth")
+
 def convert_to_halfp():
     autoregressive_model_path = get_model_path('autoregressive.pth')
     print(f'Converting model to half precision: {autoregressive_model_path}')
@@ -996,66 +1026,59 @@ def convert_to_halfp():
     torch.save(model, outfile)
     print(f'Converted model to half precision: {outfile}')
-def whisper_transcribe( file, language=None ):
-    # shouldn't happen, but it's for safety
-    if not whisper_model:
-        load_whisper_model(language=language)
-    if not args.whisper_cpp:
-        if not language:
-            language = None
-        return whisper_model.transcribe(file, language=language)
-    res = whisper_model.transcribe(file)
-    segments = whisper_model.extract_text_and_timestamps( res )
-    result = {
-        'segments': []
-    }
-    for segment in segments:
-        reparsed = {
-            'start': segment[0] / 100.0,
-            'end': segment[1] / 100.0,
-            'text': segment[2],
-        }
-        result['segments'].append(reparsed)
-    return result
+
 def prepare_dataset(files, outdir, language=None, progress=None):
     unload_tts()
     global whisper_model
     if whisper_model is None:
-        load_whisper_model(language=language)
+        load_whisper_model()
     os.makedirs(outdir, exist_ok=True)
-    idx = 0
     results = {}
     transcription = []
+    idx = 0
+    results = {}
+    transcription = []
+    if (torch.cuda.is_available()):
+        device = "cuda"
+    else:
+        device = "cpu"
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-        basename = os.path.basename(file)
-        result = whisper_transcribe(file, language=language)
-        results[basename] = result
+        print(f"Transcribing file: {file}")
+        result = whisper_model.transcribe(file)
+        print(result["segments"])  # before alignment
+        # load alignment model and metadata
+        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+        # align whisper output
+        result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
+        print(result_aligned["segments"])  # after alignment
+        print(result_aligned["word_segments"])  # after alignment
+        results[os.path.basename(file)] = result
         print(f"Transcribed file: {file}, {len(result['segments'])} found.")
         waveform, sampling_rate = torchaudio.load(file)
         num_channels, num_frames = waveform.shape
-        idx = 0
-        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+        for segment in result[
+            'segments']:  # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
             start = int(segment['start'] * sampling_rate)
             end = int(segment['end'] * sampling_rate)
             sliced_waveform = waveform[:, start:end]
-            sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
-            if not torch.any(sliced_waveform < 0):
-                print(f"Error with {sliced_name}, skipping...")
-                continue
+            sliced_name = f"{pad(idx, 4)}.wav"
             torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
@@ -1076,14 +1099,19 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
     return f"Processed dataset to: {outdir}\n{joined}"
+
 def calc_iterations(epochs, lines, batch_size):
     iterations = int(epochs * lines / float(batch_size))
     return iterations
+
 def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
     return [int(iterations * d) for d in schedule]
-def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
+
+def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size,
+                               gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers,
+                               source_model, voice):
     name = f"{voice}-finetune"
     dataset_name = f"{voice}-train"
     dataset_path = f"./training/{voice}/train.txt"
@@ -1102,7 +1130,8 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
     if batch_size % lines != 0:
         nearest_slice = int(lines / batch_size) + 1
         batch_size = int(lines / nearest_slice)
-        messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
+        messages.append(
+            f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
     if gradient_accumulation_size == 0:
         gradient_accumulation_size = 1
@@ -1112,13 +1141,15 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
         if gradient_accumulation_size == 0:
             gradient_accumulation_size = 1
-        messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
+        messages.append(
+            f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
     elif batch_size % gradient_accumulation_size != 0:
         gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
         if gradient_accumulation_size == 0:
             gradient_accumulation_size = 1
-        messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
+        messages.append(
+            f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
     iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
@@ -1140,13 +1171,15 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
     if half_p:
         if bnb:
             half_p = False
-            messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
+            messages.append(
+                "Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
         else:
             messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
             if not os.path.exists(get_halfp_model_path()):
                 convert_to_halfp()
-    messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
+    messages.append(
+        f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
     return (
         learning_rate,
@@ -1160,7 +1193,11 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
         messages
     )
-def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
+
+def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
+                           batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
+                           dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
+                           output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None):
     if not source_model:
         source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
@@ -1201,7 +1238,6 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
     if not output_name:
         output_name = f'{settings["name"]}.yaml'
-
     with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
         yaml = f.read()
@@ -1217,6 +1253,7 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
     return f"Training settings saved to: {outfile}"
+
 def import_voices(files, saveAs=None, progress=None):
     global args
@@ -1282,13 +1319,16 @@ def import_voices(files, saveAs=None, progress=None):
         print(f"Imported voice to {path}")
+
 def get_voice_list(dir=get_voice_dir(), append_defaults=False):
     os.makedirs(dir, exist_ok=True)
-    res = sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
+    res = sorted([d for d in os.listdir(dir) if
+                  os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0])
     if append_defaults:
         res = res + ["random", "microphone"]
     return res
+
 def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
     os.makedirs(dir, exist_ok=True)
     base = [get_model_path('autoregressive.pth')]
@@ -1316,11 +1356,16 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
     return res
+
 def get_dataset_list(dir="./training/"):
-    return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
+    return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(
+        os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d))])
+
 def get_training_list(dir="./training/"):
-    return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
+    return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(
+        os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d))])
+
 def do_gc():
     gc.collect()
@@ -1329,9 +1374,11 @@ def do_gc():
     except Exception as e:
         pass
+
 def pad(num, zeroes):
     return str(num).zfill(zeroes + 1)
+
 def curl(url):
     try:
         req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
@@ -1345,6 +1392,7 @@ def curl(url):
         print(e)
         return None
+
 def check_for_updates():
     if not os.path.isfile('./.git/FETCH_HEAD'):
         print("Cannot check for updates: not from a git repo")
@@ -1379,13 +1427,16 @@ def check_for_updates():
     return False
+
 def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
     if verbose and desc is not None:
         print(desc)
     if progress is None:
         return tqdm(iterable, disable=not verbose)
-    return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
+    return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc,
+                         track_tqdm=True)
+
 def notify_progress(message, progress=None, verbose=True):
     if verbose:
@@ -1396,10 +1447,12 @@ def notify_progress(message, progress=None, verbose=True):
     progress(0, desc=message)
+
 def get_args():
     global args
     return args
+
 def setup_args():
     global args
@@ -1412,7 +1465,8 @@ def setup_args():
         'sample-batch-size': None,
         'embed-output-metadata': True,
         'latents-lean-and-mean': True,
-        'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
+        'voice-fixer': False,
+        # getting tired of long initialization times in a Colab for downloading a large dataset for it
         'voice-fixer-use-cuda': True,
         'force-cpu-for-conditioning-latents': False,
         'defer-tts-load': False,
@@ -1443,32 +1497,61 @@ def setup_args():
         pass
     parser = argparse.ArgumentParser()
-    parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
+    parser.add_argument("--share", action='store_true', default=default_arguments['share'],
+                        help="Lets Gradio return a public URL to use anywhere")
     parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
-    parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
-    parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
-    parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
-    parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
-    parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
-    parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
-    parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
-    parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
-    parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
-    parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
-    parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true', help="Uses BigVGAN in place of the default vocoder")
-    parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
-    parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
-    parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
-    parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
-    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
-    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
-    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
-    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
-    parser.add_argument("--whisper-cpp", default=default_arguments['whisper-cpp'], action='store_true', help="Leverages lightmare/whispercpp for transcription")
-    parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
-    parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
+    parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'],
+                        help="Checks for update on startup")
+    parser.add_argument("--models-from-local-only", action='store_true',
+                        default=default_arguments['models-from-local-only'],
+                        help="Only loads models from disk, does not check for updates for models")
+    parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'],
+                        help="Disables some optimizations that increases VRAM usage")
+    parser.add_argument("--no-embed-output-metadata", action='store_false',
+                        default=not default_arguments['embed-output-metadata'],
+                        help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
+    parser.add_argument("--latents-lean-and-mean", action='store_true',
+                        default=default_arguments['latents-lean-and-mean'],
+                        help="Exports the bare essentials for latents.")
+    parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'],
+                        help="Uses python module 'voicefixer' to improve audio quality, if available.")
+    parser.add_argument("--voice-fixer-use-cuda", action='store_true',
+                        default=default_arguments['voice-fixer-use-cuda'],
+                        help="Hints to voicefixer to use CUDA, if available.")
+    parser.add_argument("--force-cpu-for-conditioning-latents",
+                        default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true',
+                        help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
+    parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true',
+                        help="Defers loading TTS model")
+    parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'],
+                        action='store_true', help="Deletes non-final output files on completing a generation")
+    parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true',
+                        help="Uses BigVGAN in place of the default vocoder")
+    parser.add_argument("--device-override", default=default_arguments['device-override'],
+                        help="A device string to override pass through Torch")
+    parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int,
+                        help="Sets how many batches to use during the autoregressive samples pass")
+    parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'],
+                        help="How many Gradio events to process at once")
+    parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float,
+                        default=default_arguments['autocalculate-voice-chunk-duration-size'],
+                        help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
+    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'],
+                        help="Sample rate to resample the output to (from 24KHz)")
+    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'],
+                        help="Adjusts volume of output")
+    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'],
+                        help="Specifies which autoregressive model to use for sampling.")
+    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'],
+                        help="Specifies which whisper model to use for transcription.")
+    parser.add_argument("--whisper-cpp", default=default_arguments['whisper-cpp'], action='store_true',
+                        help="Leverages lightmare/whispercpp for transcription")
+    parser.add_argument("--training-default-halfp", action='store_true',
+                        default=default_arguments['training-default-halfp'], help="Training default: halfp")
+    parser.add_argument("--training-default-bnb", action='store_true',
+                        default=default_arguments['training-default-bnb'], help="Training default: bnb")
     parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
     args = parser.parse_args()
@@ -1478,7 +1561,6 @@ def setup_args():
     if not args.device_override:
         set_device_name(args.device_override)
-
     args.listen_host = None
     args.listen_port = None
     args.listen_path = None
@@ -1499,7 +1581,12 @@ def setup_args():
     return args
-def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
+
+def update_args(listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata,
+                latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents,
+                defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size,
+                concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model,
+                whisper_model, whisper_cpp, training_default_halfp, training_default_bnb):
     global args
     args.listen = listen
@@ -1531,6 +1618,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
     save_args_settings()
+
 def save_args_settings():
     global args
     settings = {
@@ -1567,7 +1655,6 @@ def save_args_settings():
         f.write(json.dumps(settings, indent='\t'))
-
 def import_generate_settings(file="./config/generate.json"):
     settings, _ = read_generate_settings(file, read_latents=False)
@@ -1604,6 +1691,7 @@ def reset_generation_settings():
         f.write(json.dumps({}, indent='\t'))
     return import_generate_settings()
+
 def read_generate_settings(file, read_latents=True):
     j = None
     latents = None
@@ -1632,17 +1720,15 @@ def read_generate_settings(file, read_latents=True):
             latents = base64.b64decode(j['latents'])
             del j['latents']
-
         if "time" in j:
             j["time"] = "{:.3f}".format(j["time"])
-
     return (
         j,
         latents,
     )
+
 def load_tts(restart=False, model=None):
     global args
     global tts
@@ -1650,7 +1736,6 @@ def load_tts( restart=False, model=None ):
     if restart:
         unload_tts()
-
     if model:
         args.autoregressive_model = model
@@ -1672,8 +1757,10 @@ def load_tts( restart=False, model=None ):
     print("Loaded TorToiSe, ready for generation.")
     return tts
+
 setup_tortoise = load_tts
+
 def unload_tts():
     global tts
@@ -1683,9 +1770,11 @@ def unload_tts():
         print("Unloaded TTS")
         do_gc()
+
 def reload_tts(model=None):
     load_tts(restart=True, model=model)
+
 def update_autoregressive_model(autoregressive_model_path):
     match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
     if match:
@@ -1714,7 +1803,8 @@ def update_autoregressive_model(autoregressive_model_path):
     else:
         from tortoise.models.autoregressive import UnifiedVoice
-        tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir)
+        tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(
+            autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir)
         del tts.autoregressive
         tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
@@ -1735,6 +1825,7 @@ def update_autoregressive_model(autoregressive_model_path):
     return autoregressive_model_path
+
 def load_voicefixer(restart=False):
     global voicefixer
@@ -1749,6 +1840,7 @@ def load_voicefixer(restart=False):
     except Exception as e:
         print(f"Error occurred while tring to initialize voicefixer: {e}")
+
 def unload_voicefixer():
     global voicefixer
@@ -1759,7 +1851,8 @@ def unload_voicefixer():
         do_gc()
-def load_whisper_model(language=None, model_name=None, progress=None):
+
+def load_whisper_model(model_name=None, progress=None):
     global whisper_model
     if not model_name:
@@ -1768,24 +1861,17 @@ def load_whisper_model(language=None, model_name=None, progress=None):
         args.whisper_model = model_name
         save_args_settings()
-    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
-        model_name = f'{model_name}.{language}'
-        print(f"Loading specialized model for language: {language}")
-    notify_progress(f"Loading Whisper model: {model_name}", progress)
-    if args.whisper_cpp:
-        from whispercpp import Whisper
-        if not language:
-            language = 'auto'
-        b_lang = language.encode('ascii')
-        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
+    if (torch.cuda.is_available()):
+        device = "cuda"
     else:
-        import whisper
-        whisper_model = whisper.load_model(model_name)
-    print("Loaded Whisper model")
+        device = "cpu"
+    notify_progress(f"Loading WhisperX model: {model_name} using {device}", progress)
+    whisper_model = whisperx.load_model(model_name, device)
+    print("Loaded WhisperX model")
 def unload_whisper():
     global whisper_model
@@ -1793,6 +1879,6 @@ def unload_whisper():
     if whisper_model:
         del whisper_model
         whisper_model = None
-        print("Unloaded Whisper")
+        print("Unloaded WhisperX")
         do_gc()
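
Taken together, the src/utils.py changes reduce dataset preparation to a transcribe-then-align flow. As a minimal sketch, assuming the early-2023 whisperx API exactly as it appears in the hunks above ("audio.wav" is an illustrative stand-in for a real voice file):

import torch
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"

# transcribe with the whisper backbone (coarse, segment-level timestamps)
model = whisperx.load_model("base", device)
result = model.transcribe("audio.wav")

# load an alignment model for the detected language, then refine timestamps to word level
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(result["segments"], model_a, metadata, "audio.wav", device)

print(aligned["segments"])       # segment timestamps after alignment
print(aligned["word_segments"])  # word-level timestamps

One caveat visible in the diff itself: prepare_dataset computes result_aligned but still slices audio from the unaligned result['segments'], so the alignment output is only printed, never used for segmentation.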

src/webui.py

@@ -590,9 +590,8 @@ def setup_gradio():
             autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
             whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
-            use_whisper_cpp = gr.Checkbox(label="Use Whisper.cpp", value=args.whisper_cpp)
-            exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, use_whisper_cpp, training_halfp, training_bnb ]
+            exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, training_halfp, training_bnb ]
     with gr.Row():
         autoregressive_models_update_button = gr.Button(value="Refresh Model List")
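
Note that src/utils.py's update_args still accepts a whisper_cpp parameter (and the --whisper-cpp flag remains in setup_args), while this hunk removes the corresponding checkbox from exec_inputs. If exec_inputs is passed to update_args positionally, as the parallel naming suggests, the two argument lists have drifted out of sync and the settings callback would need the parameter dropped as well.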