forked from ecker/ai-voice-cloning
Update 'src/utils.py'
whisper->whisperx
This commit is contained in:
parent
4f123910fb
commit
f657f30e2b
257
src/utils.py
257
src/utils.py
@ -28,6 +28,7 @@ import music_tag
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import whisperx
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
@ -40,7 +41,6 @@ from tortoise.utils.device import get_device_name, set_device_name
|
|||||||
MODELS[
|
MODELS[
|
||||||
'dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
'dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
|
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
|
||||||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
|
||||||
EPOCH_SCHEDULE = [9, 18, 25, 33]
|
EPOCH_SCHEDULE = [9, 18, 25, 33]
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
@ -943,13 +943,6 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
|||||||
training_state = None
|
training_state = None
|
||||||
|
|
||||||
|
|
||||||
def get_training_losses():
|
|
||||||
global training_state
|
|
||||||
if not training_state or not training_state.statistics:
|
|
||||||
return
|
|
||||||
return pd.DataFrame(training_state.statistics)
|
|
||||||
|
|
||||||
|
|
||||||
def update_training_dataplot(config_path=None):
|
def update_training_dataplot(config_path=None):
|
||||||
global training_state
|
global training_state
|
||||||
update = None
|
update = None
|
||||||
@ -958,12 +951,17 @@ def update_training_dataplot(config_path=None):
|
|||||||
if config_path:
|
if config_path:
|
||||||
training_state = TrainingState(config_path=config_path, start=False)
|
training_state = TrainingState(config_path=config_path, start=False)
|
||||||
if training_state.statistics:
|
if training_state.statistics:
|
||||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics),
|
||||||
|
x_lim=[0, training_state.its], x="step", y="value",
|
||||||
|
title="Training Metrics", color="type", tooltip=['step', 'value', 'type'],
|
||||||
|
width=600, height=350, )
|
||||||
del training_state
|
del training_state
|
||||||
training_state = None
|
training_state = None
|
||||||
elif training_state.statistics:
|
elif training_state.statistics:
|
||||||
training_state.load_losses()
|
training_state.load_losses()
|
||||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0, training_state.its],
|
||||||
|
x="step", y="value", title="Training Metrics", color="type",
|
||||||
|
tooltip=['step', 'value', 'type'], width=600, height=350, )
|
||||||
|
|
||||||
return update
|
return update
|
||||||
|
|
||||||
@ -1033,18 +1031,8 @@ def prepare_dataset(files, outdir, language=None, progress=None):
|
|||||||
unload_tts()
|
unload_tts()
|
||||||
|
|
||||||
global whisper_model
|
global whisper_model
|
||||||
import whisperx
|
if whisper_model is None:
|
||||||
|
load_whisper_model()
|
||||||
device = "cuda" # add cpu option?
|
|
||||||
|
|
||||||
# original whisper https://github.com/openai/whisper
|
|
||||||
# whisperx fork https://github.com/m-bain/whisperX
|
|
||||||
# supports en, fr, de, es, it, ja, zh, nl, uk, pt
|
|
||||||
|
|
||||||
# tiny, base, small, medium, large, large-v2
|
|
||||||
|
|
||||||
whisper_model = whisperx.load_model("medium", device)
|
|
||||||
# some additional model features require huggingface token
|
|
||||||
|
|
||||||
os.makedirs(outdir, exist_ok=True)
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
@ -1052,6 +1040,15 @@ def prepare_dataset(files, outdir, language=None, progress=None):
|
|||||||
results = {}
|
results = {}
|
||||||
transcription = []
|
transcription = []
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
results = {}
|
||||||
|
transcription = []
|
||||||
|
|
||||||
|
if (torch.cuda.is_available()):
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
print(f"Transcribing file: {file}")
|
print(f"Transcribing file: {file}")
|
||||||
|
|
||||||
@ -1091,15 +1088,46 @@ def prepare_dataset(files, outdir, language=None, progress=None):
|
|||||||
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
||||||
f.write(f'{line}\n')
|
f.write(f'{line}\n')
|
||||||
|
|
||||||
|
'''for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
|
basename = os.path.basename(file)
|
||||||
|
result = whisper_transcribe(file, language=language)
|
||||||
|
results[basename] = result
|
||||||
|
print(f"Transcribed file: {file}, {len(result['segments'])} found.")
|
||||||
|
|
||||||
|
waveform, sampling_rate = torchaudio.load(file)
|
||||||
|
num_channels, num_frames = waveform.shape
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for segment in result[
|
||||||
|
'segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||||
|
start = int(segment['start'] * sampling_rate)
|
||||||
|
end = int(segment['end'] * sampling_rate)
|
||||||
|
|
||||||
|
sliced_waveform = waveform[:, start:end]
|
||||||
|
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
||||||
|
|
||||||
|
if not torch.any(sliced_waveform < 0):
|
||||||
|
print(f"Error with {sliced_name}, skipping...")
|
||||||
|
continue
|
||||||
|
|
||||||
|
torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
|
||||||
|
|
||||||
|
idx = idx + 1
|
||||||
|
line = f"{sliced_name}|{segment['text'].strip()}"
|
||||||
|
transcription.append(line)
|
||||||
|
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
||||||
|
f.write(f'{line}\n')
|
||||||
|
'''
|
||||||
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(results, indent='\t'))
|
f.write(json.dumps(results, indent='\t'))
|
||||||
|
|
||||||
|
joined = '\n'.join(transcription)
|
||||||
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
||||||
f.write("\n".join(transcription))
|
f.write(joined)
|
||||||
|
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
|
|
||||||
return f"Processed dataset to: {outdir}"
|
return f"Processed dataset to: {outdir}\n{joined}"
|
||||||
|
|
||||||
|
|
||||||
def calc_iterations(epochs, lines, batch_size):
|
def calc_iterations(epochs, lines, batch_size):
|
||||||
@ -1196,159 +1224,6 @@ def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learnin
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
|
|
||||||
batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
|
|
||||||
dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
|
|
||||||
output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None):
|
|
||||||
if not source_model:
|
|
||||||
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
|
|
||||||
|
|
||||||
settings = {
|
|
||||||
"iterations": iterations if iterations else 500,
|
|
||||||
"batch_size": batch_size if batch_size else 64,
|
|
||||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
|
||||||
"gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE,
|
|
||||||
"gradient_accumulation_size": gradient_accumulation_size if gradient_accumulation_size else 4,
|
|
||||||
"print_rate": print_rate if print_rate else 1,
|
|
||||||
"save_rate": save_rate if save_rate else 50,
|
|
||||||
"name": name if name else "finetune",
|
|
||||||
"dataset_name": dataset_name if dataset_name else "finetune",
|
|
||||||
"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
|
|
||||||
"validation_name": validation_name if validation_name else "finetune",
|
|
||||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
|
||||||
|
|
||||||
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
|
|
||||||
|
|
||||||
'resume_state': f"resume_state: '{resume_path}'",
|
|
||||||
'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",
|
|
||||||
|
|
||||||
'float16': 'true' if half_p else 'false',
|
|
||||||
'bitsandbytes': 'true' if bnb else 'false',
|
|
||||||
|
|
||||||
'workers': workers if workers else 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
if resume_path:
|
|
||||||
settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
|
|
||||||
else:
|
|
||||||
settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'"
|
|
||||||
|
|
||||||
if half_p:
|
|
||||||
if not os.path.exists(get_halfp_model_path()):
|
|
||||||
convert_to_halfp()
|
|
||||||
|
|
||||||
if not output_name:
|
|
||||||
output_name = f'{settings["name"]}.yaml'
|
|
||||||
|
|
||||||
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
|
|
||||||
yaml = f.read()
|
|
||||||
|
|
||||||
# i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
|
|
||||||
for k in settings:
|
|
||||||
if settings[k] is None:
|
|
||||||
continue
|
|
||||||
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
|
||||||
|
|
||||||
outfile = f'./training/{output_name}'
|
|
||||||
with open(outfile, 'w', encoding="utf-8") as f:
|
|
||||||
f.write(yaml)
|
|
||||||
|
|
||||||
return f"Training settings saved to: {outfile}"
|
|
||||||
|
|
||||||
def calc_iterations(epochs, lines, batch_size):
|
|
||||||
iterations = int(epochs * lines / float(batch_size))
|
|
||||||
return iterations
|
|
||||||
|
|
||||||
|
|
||||||
def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
|
|
||||||
return [int(iterations * d) for d in schedule]
|
|
||||||
|
|
||||||
|
|
||||||
def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size,
|
|
||||||
gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers,
|
|
||||||
source_model, voice):
|
|
||||||
name = f"{voice}-finetune"
|
|
||||||
dataset_name = f"{voice}-train"
|
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
|
||||||
validation_name = f"{voice}-val"
|
|
||||||
validation_path = f"./training/{voice}/train.txt"
|
|
||||||
|
|
||||||
with open(dataset_path, 'r', encoding="utf-8") as f:
|
|
||||||
lines = len(f.readlines())
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
if batch_size > lines:
|
|
||||||
batch_size = lines
|
|
||||||
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
|
|
||||||
|
|
||||||
if batch_size % lines != 0:
|
|
||||||
nearest_slice = int(lines / batch_size) + 1
|
|
||||||
batch_size = int(lines / nearest_slice)
|
|
||||||
messages.append(
|
|
||||||
f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
|
|
||||||
|
|
||||||
if gradient_accumulation_size == 0:
|
|
||||||
gradient_accumulation_size = 1
|
|
||||||
|
|
||||||
if batch_size / gradient_accumulation_size < 2:
|
|
||||||
gradient_accumulation_size = int(batch_size / 2)
|
|
||||||
if gradient_accumulation_size == 0:
|
|
||||||
gradient_accumulation_size = 1
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
|
|
||||||
elif batch_size % gradient_accumulation_size != 0:
|
|
||||||
gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
|
|
||||||
if gradient_accumulation_size == 0:
|
|
||||||
gradient_accumulation_size = 1
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
|
|
||||||
|
|
||||||
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
|
||||||
|
|
||||||
if epochs < print_rate:
|
|
||||||
print_rate = epochs
|
|
||||||
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
|
|
||||||
|
|
||||||
if epochs < save_rate:
|
|
||||||
save_rate = epochs
|
|
||||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
|
||||||
|
|
||||||
if resume_path and not os.path.exists(resume_path):
|
|
||||||
resume_path = None
|
|
||||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
|
||||||
|
|
||||||
if bnb:
|
|
||||||
messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
|
|
||||||
|
|
||||||
if half_p:
|
|
||||||
if bnb:
|
|
||||||
half_p = False
|
|
||||||
messages.append(
|
|
||||||
"Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
|
|
||||||
else:
|
|
||||||
messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
|
|
||||||
if not os.path.exists(get_halfp_model_path()):
|
|
||||||
convert_to_halfp()
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
|
|
||||||
|
|
||||||
return (
|
|
||||||
learning_rate,
|
|
||||||
text_ce_lr_weight,
|
|
||||||
learning_rate_schedule,
|
|
||||||
batch_size,
|
|
||||||
gradient_accumulation_size,
|
|
||||||
print_rate,
|
|
||||||
save_rate,
|
|
||||||
resume_path,
|
|
||||||
messages
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
|
def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
|
||||||
batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
|
batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
|
||||||
dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
|
dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
|
||||||
@ -2007,7 +1882,7 @@ def unload_voicefixer():
|
|||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
|
|
||||||
def load_whisper_model(language=None, model_name=None, progress=None):
|
def load_whisper_model(model_name=None, progress=None):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
|
||||||
if not model_name:
|
if not model_name:
|
||||||
@ -2016,24 +1891,16 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
|||||||
args.whisper_model = model_name
|
args.whisper_model = model_name
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
|
|
||||||
if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
|
if (torch.cuda.is_available()):
|
||||||
model_name = f'{model_name}.{language}'
|
device = "cuda"
|
||||||
print(f"Loading specialized model for language: {language}")
|
|
||||||
|
|
||||||
notify_progress(f"Loading Whisper model: {model_name}", progress)
|
|
||||||
|
|
||||||
if args.whisper_cpp:
|
|
||||||
from whispercpp import Whisper
|
|
||||||
if not language:
|
|
||||||
language = 'auto'
|
|
||||||
|
|
||||||
b_lang = language.encode('ascii')
|
|
||||||
whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
|
|
||||||
else:
|
else:
|
||||||
import whisper
|
device = "cpu"
|
||||||
whisper_model = whisper.load_model(model_name)
|
|
||||||
|
|
||||||
print("Loaded Whisper model")
|
notify_progress(f"Loading WhisperX model: {model_name} using {device}", progress)
|
||||||
|
|
||||||
|
whisper_model = whisperx.load_model(model_name, device)
|
||||||
|
|
||||||
|
print("Loaded WhisperX model")
|
||||||
|
|
||||||
|
|
||||||
def unload_whisper():
|
def unload_whisper():
|
||||||
@ -2042,6 +1909,6 @@ def unload_whisper():
|
|||||||
if whisper_model:
|
if whisper_model:
|
||||||
del whisper_model
|
del whisper_model
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
print("Unloaded Whisper")
|
print("Unloaded WhisperX")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
Loading…
Reference in New Issue
Block a user