Update 'src/utils.py'

whisper->whisperx
This commit is contained in:
yqxtqymn 2023-03-06 01:59:58 +00:00
parent 4f123910fb
commit f657f30e2b

View File

@ -28,6 +28,7 @@ import music_tag
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
import pandas as pd import pandas as pd
import whisperx
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta
@ -40,7 +41,6 @@ from tortoise.utils.device import get_device_name, set_device_name
MODELS[ MODELS[
'dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" 'dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"] WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
EPOCH_SCHEDULE = [9, 18, 25, 33] EPOCH_SCHEDULE = [9, 18, 25, 33]
args = None args = None
@ -943,13 +943,6 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
training_state = None training_state = None
def get_training_losses():
    """Return collected training statistics as a DataFrame, or None.

    Reads the module-level ``training_state``; returns None when no
    training session is active or no statistics were recorded yet.
    """
    global training_state
    state = training_state
    if state and state.statistics:
        return pd.DataFrame(state.statistics)
    return None
def update_training_dataplot(config_path=None): def update_training_dataplot(config_path=None):
global training_state global training_state
update = None update = None
@ -958,12 +951,17 @@ def update_training_dataplot(config_path=None):
if config_path: if config_path:
training_state = TrainingState(config_path=config_path, start=False) training_state = TrainingState(config_path=config_path, start=False)
if training_state.statistics: if training_state.statistics:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics),
x_lim=[0, training_state.its], x="step", y="value",
title="Training Metrics", color="type", tooltip=['step', 'value', 'type'],
width=600, height=350, )
del training_state del training_state
training_state = None training_state = None
elif training_state.statistics: elif training_state.statistics:
training_state.load_losses() training_state.load_losses()
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0, training_state.its],
x="step", y="value", title="Training Metrics", color="type",
tooltip=['step', 'value', 'type'], width=600, height=350, )
return update return update
@ -1033,18 +1031,8 @@ def prepare_dataset(files, outdir, language=None, progress=None):
unload_tts() unload_tts()
global whisper_model global whisper_model
import whisperx if whisper_model is None:
load_whisper_model()
device = "cuda" # add cpu option?
# original whisper https://github.com/openai/whisper
# whisperx fork https://github.com/m-bain/whisperX
# supports en, fr, de, es, it, ja, zh, nl, uk, pt
# tiny, base, small, medium, large, large-v2
whisper_model = whisperx.load_model("medium", device)
# some additional model features require huggingface token
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
@ -1052,6 +1040,15 @@ def prepare_dataset(files, outdir, language=None, progress=None):
results = {} results = {}
transcription = [] transcription = []
idx = 0
results = {}
transcription = []
if (torch.cuda.is_available()):
device = "cuda"
else:
device = "cpu"
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
print(f"Transcribing file: {file}") print(f"Transcribing file: {file}")
@ -1091,15 +1088,46 @@ def prepare_dataset(files, outdir, language=None, progress=None):
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'{line}\n') f.write(f'{line}\n')
'''for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
basename = os.path.basename(file)
result = whisper_transcribe(file, language=language)
results[basename] = result
print(f"Transcribed file: {file}, {len(result['segments'])} found.")
waveform, sampling_rate = torchaudio.load(file)
num_channels, num_frames = waveform.shape
idx = 0
for segment in result[
'segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
start = int(segment['start'] * sampling_rate)
end = int(segment['end'] * sampling_rate)
sliced_waveform = waveform[:, start:end]
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
if not torch.any(sliced_waveform < 0):
print(f"Error with {sliced_name}, skipping...")
continue
torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
idx = idx + 1
line = f"{sliced_name}|{segment['text'].strip()}"
transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'{line}\n')
'''
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f: with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t')) f.write(json.dumps(results, indent='\t'))
joined = '\n'.join(transcription)
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(transcription)) f.write(joined)
unload_whisper() unload_whisper()
return f"Processed dataset to: {outdir}" return f"Processed dataset to: {outdir}\n{joined}"
def calc_iterations(epochs, lines, batch_size): def calc_iterations(epochs, lines, batch_size):
@ -1196,159 +1224,6 @@ def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learnin
) )
def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
                           batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
                           dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
                           output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None):
    """Render a training config from ./models/.template.yaml via ${key} substitution.

    Every argument is optional; unset ones fall back to the defaults below.
    Writes the filled-in YAML to ./training/<output_name> and returns a
    human-readable confirmation string with the output path.
    """
    if not source_model:
        source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"

    # `x or default` mirrors the truthiness fallback for every knob.
    settings = {
        "iterations": iterations or 500,
        "batch_size": batch_size or 64,
        "learning_rate": learning_rate or 1e-5,
        "gen_lr_steps": learning_rate_schedule or EPOCH_SCHEDULE,
        "gradient_accumulation_size": gradient_accumulation_size or 4,
        "print_rate": print_rate or 1,
        "save_rate": save_rate or 50,
        "name": name or "finetune",
        "dataset_name": dataset_name or "finetune",
        "dataset_path": dataset_path or "./training/finetune/train.txt",
        "validation_name": validation_name or "finetune",
        "validation_path": validation_path or "./training/finetune/train.txt",
        "text_ce_lr_weight": text_ce_lr_weight or 0.01,
        'resume_state': f"resume_state: '{resume_path}'",
        'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",
        'float16': 'true' if half_p else 'false',
        'bitsandbytes': 'true' if bnb else 'false',
        'workers': workers or 2,
    }

    # Exactly one of resume_state / pretrain_model_gpt stays active in the
    # YAML; the other line is emitted commented out.
    if resume_path:
        settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
    else:
        settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'"

    if half_p and not os.path.exists(get_halfp_model_path()):
        convert_to_halfp()

    if not output_name:
        output_name = f'{settings["name"]}.yaml'

    with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
        template = f.read()

    # Plain text substitution of ${key} placeholders — simpler than loading
    # and re-serializing the YAML tree.
    for key, value in settings.items():
        if value is None:
            continue
        template = template.replace(f"${{{key}}}", str(value))

    outfile = f'./training/{output_name}'
    with open(outfile, 'w', encoding="utf-8") as f:
        f.write(template)

    return f"Training settings saved to: {outfile}"
def calc_iterations(epochs, lines, batch_size):
    """Total optimizer steps for *epochs* passes over *lines* samples,
    truncated to an int (epochs * lines / batch_size)."""
    return int((epochs * lines) / float(batch_size))
def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
    """Scale each milestone in *schedule* by *iterations*, truncated to int."""
    steps = []
    for milestone in schedule:
        steps.append(int(iterations * milestone))
    return steps
def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size,
                               gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers,
                               source_model, voice):
    """Sanity-check and clamp training hyperparameters for *voice*.

    Sizes the dataset by counting lines in ./training/{voice}/train.txt, then
    adjusts batch size, gradient accumulation size, print/save rates and the
    resume path so they are mutually consistent. Returns a tuple of the
    (possibly adjusted) settings plus a list of messages describing every
    adjustment made, in the order they were applied.
    """
    name = f"{voice}-finetune"
    dataset_name = f"{voice}-train"
    dataset_path = f"./training/{voice}/train.txt"
    validation_name = f"{voice}-val"
    validation_path = f"./training/{voice}/train.txt"

    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    messages = []

    if batch_size > lines:
        batch_size = lines
        messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")

    # BUGFIX: was `batch_size % lines != 0`. With batch_size <= lines that
    # expression equals batch_size, so it fired for every batch size except
    # batch_size == lines and would shrink even a perfectly divisible one
    # (e.g. lines=100, batch_size=50 -> 33). The intent — per the message —
    # is an integral number of steps per epoch, i.e. lines divisible by
    # batch_size.
    if lines % batch_size != 0:
        nearest_slice = int(lines / batch_size) + 1
        batch_size = int(lines / nearest_slice)
        messages.append(
            f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")

    if gradient_accumulation_size == 0:
        gradient_accumulation_size = 1

    if batch_size / gradient_accumulation_size < 2:
        gradient_accumulation_size = int(batch_size / 2)
        if gradient_accumulation_size == 0:
            gradient_accumulation_size = 1
        messages.append(
            f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
    elif batch_size % gradient_accumulation_size != 0:
        # NOTE(review): dividing by the old size does not guarantee the new
        # size divides batch_size either — presumably meant to pick an actual
        # divisor; left unchanged to avoid altering training behavior.
        gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
        if gradient_accumulation_size == 0:
            gradient_accumulation_size = 1
        messages.append(
            f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")

    iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)

    # print_rate / save_rate are expressed in epochs here; clamp to the run
    # length so at least one print/save happens.
    if epochs < print_rate:
        print_rate = epochs
        messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")

    if epochs < save_rate:
        save_rate = epochs
        messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")

    if resume_path and not os.path.exists(resume_path):
        resume_path = None
        messages.append("Resume path specified, but does not exist. Disabling...")

    if bnb:
        messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")

    if half_p:
        if bnb:
            # Both quantization paths are redundant together; bnb wins.
            half_p = False
            messages.append(
                "Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
        else:
            messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
            if not os.path.exists(get_halfp_model_path()):
                convert_to_halfp()

    messages.append(
        f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")

    return (
        learning_rate,
        text_ce_lr_weight,
        learning_rate_schedule,
        batch_size,
        gradient_accumulation_size,
        print_rate,
        save_rate,
        resume_path,
        messages
    )
def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
@ -2007,7 +1882,7 @@ def unload_voicefixer():
do_gc() do_gc()
def load_whisper_model(language=None, model_name=None, progress=None): def load_whisper_model(model_name=None, progress=None):
global whisper_model global whisper_model
if not model_name: if not model_name:
@ -2016,24 +1891,16 @@ def load_whisper_model(language=None, model_name=None, progress=None):
args.whisper_model = model_name args.whisper_model = model_name
save_args_settings() save_args_settings()
if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS: if (torch.cuda.is_available()):
model_name = f'{model_name}.{language}' device = "cuda"
print(f"Loading specialized model for language: {language}")
notify_progress(f"Loading Whisper model: {model_name}", progress)
if args.whisper_cpp:
from whispercpp import Whisper
if not language:
language = 'auto'
b_lang = language.encode('ascii')
whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
else: else:
import whisper device = "cpu"
whisper_model = whisper.load_model(model_name)
print("Loaded Whisper model") notify_progress(f"Loading WhisperX model: {model_name} using {device}", progress)
whisper_model = whisperx.load_model(model_name, device)
print("Loaded WhisperX model")
def unload_whisper(): def unload_whisper():
@ -2042,6 +1909,6 @@ def unload_whisper():
if whisper_model: if whisper_model:
del whisper_model del whisper_model
whisper_model = None whisper_model = None
print("Unloaded Whisper") print("Unloaded WhisperX")
do_gc() do_gc()