Update 'src/utils.py'

whisper->whisperx
Author: yqxtqymn
Date: 2023-03-06 01:59:58 +00:00
parent 4f123910fb
commit f657f30e2b


@@ -28,6 +28,7 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
+import whisperx
 from datetime import datetime
 from datetime import timedelta
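
For orientation: the WhisperX pipeline this commit moves to is transcribe-then-align. A plain Whisper pass yields coarse segments, then a per-language phoneme model refines their timestamps. A minimal sketch assembled from the calls used in the hunks below, against the early-2023 WhisperX API ("audio.wav" is a placeholder path):

    import torch
    import whisperx

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = whisperx.load_model("medium", device)  # any entry from WHISPER_MODELS
    result = model.transcribe("audio.wav")         # segment-level timestamps

    # pick the alignment model matching the language Whisper detected
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    aligned = whisperx.align(result["segments"], model_a, metadata, "audio.wav", device)

    print(aligned["word_segments"])                # word-level timestamps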
@@ -40,7 +41,6 @@ from tortoise.utils.device import get_device_name, set_device_name
 MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
 WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
-WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
 EPOCH_SCHEDULE = [9, 18, 25, 33]
 args = None
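
(Note: WHISPER_SPECIALIZED_MODELS had a single consumer, the language-specialization branch in load_whisper_model, which a later hunk in this commit also deletes. WhisperX reports the detected language per file as result["language"], so the English-only ".en" checkpoints no longer need their own list.)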
@@ -332,17 +332,17 @@ def generate(
         }
     """
     # kludgy yucky codesmells
     for name in audio_cache:
         if 'output' not in audio_cache[name]:
             continue

         #output_voices.append(f'{outdir}/{voice}_{name}.wav')
         output_voices.append(name)

         if not args.embed_output_metadata:
             with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
                 f.write(json.dumps(info, indent='\t') )
     """
     if args.voice_fixer:
         if not voicefixer:
@@ -943,13 +943,6 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
     training_state = None

-def get_training_losses():
-    global training_state
-    if not training_state or not training_state.statistics:
-        return
-
-    return pd.DataFrame(training_state.statistics)
-
 def update_training_dataplot(config_path=None):
     global training_state
     update = None
@@ -958,12 +951,17 @@ def update_training_dataplot(config_path=None):
     if config_path:
         training_state = TrainingState(config_path=config_path, start=False)
         if training_state.statistics:
-            update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
+            update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics),
+                                        x_lim=[0, training_state.its], x="step", y="value",
+                                        title="Training Metrics", color="type", tooltip=['step', 'value', 'type'],
+                                        width=600, height=350, )
         del training_state
         training_state = None
     elif training_state.statistics:
         training_state.load_losses()
-        update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
+        update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0, training_state.its],
+                                    x="step", y="value", title="Training Metrics", color="type",
+                                    tooltip=['step', 'value', 'type'], width=600, height=350, )
     return update
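
The keyword arguments threaded through both update calls imply that training_state.statistics is a long-format table with step/value/type records. A minimal sketch of the assumed shape (rows illustrative, loss names hypothetical):

    import pandas as pd
    import gradio as gr

    stats = pd.DataFrame([
        {"step": 100, "value": 2.31, "type": "loss_text_ce"},
        {"step": 100, "value": 4.89, "type": "loss_mel_ce"},
        {"step": 200, "value": 2.05, "type": "loss_text_ce"},
    ])
    update = gr.LinePlot.update(value=stats, x="step", y="value", color="type",
                                title="Training Metrics", tooltip=["step", "value", "type"],
                                width=600, height=350)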
@@ -1030,231 +1028,108 @@ def convert_to_halfp():
 def prepare_dataset(files, outdir, language=None, progress=None):
     unload_tts()

     global whisper_model
-    import whisperx
-
-    device = "cuda" # add cpu option?
-    # original whisper https://github.com/openai/whisper
-    # whisperx fork https://github.com/m-bain/whisperX
-    # supports en, fr, de, es, it, ja, zh, nl, uk, pt
-    # tiny, base, small, medium, large, large-v2
-    whisper_model = whisperx.load_model("medium", device)
-    # some additional model features require huggingface token
-
-    os.makedirs(outdir, exist_ok=True)
-
-    idx = 0
-    results = {}
-    transcription = []
-
-    for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-        print(f"Transcribing file: {file}")
-        result = whisper_model.transcribe(file)
-
-        print(result["segments"]) # before alignment
-
-        # load alignment model and metadata
-        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-        # align whisper output
-        result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
-
-        print(result_aligned["segments"]) # after alignment
-        print(result_aligned["word_segments"]) # after alignment
-        results[os.path.basename(file)] = result
-
-        print(f"Transcribed file: {file}, {len(result['segments'])} found.")
-
-        waveform, sampling_rate = torchaudio.load(file)
-        num_channels, num_frames = waveform.shape
-        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-            start = int(segment['start'] * sampling_rate)
-            end = int(segment['end'] * sampling_rate)
-
-            sliced_waveform = waveform[:, start:end]
-            sliced_name = f"{pad(idx, 4)}.wav"
-
-            torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
-
-            idx = idx + 1
-            line = f"{sliced_name}|{segment['text'].strip()}"
-            transcription.append(line)
-            with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-                f.write(f'{line}\n')
-
-    with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
-        f.write(json.dumps(results, indent='\t'))
-
-    with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-        f.write("\n".join(transcription))
-
-    unload_whisper()
-
-    return f"Processed dataset to: {outdir}"
-
-def calc_iterations(epochs, lines, batch_size):
-    iterations = int(epochs * lines / float(batch_size))
-    return iterations
-
-def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
-    return [int(iterations * d) for d in schedule]
-
-def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice):
-    name = f"{voice}-finetune"
-    dataset_name = f"{voice}-train"
-    dataset_path = f"./training/{voice}/train.txt"
-    validation_name = f"{voice}-val"
-    validation_path = f"./training/{voice}/train.txt"
-
-    with open(dataset_path, 'r', encoding="utf-8") as f:
-        lines = len(f.readlines())
-
-    messages = []
-
-    if batch_size > lines:
-        batch_size = lines
-        messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
-
-    if batch_size % lines != 0:
-        nearest_slice = int(lines / batch_size) + 1
-        batch_size = int(lines / nearest_slice)
-        messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
-
-    if gradient_accumulation_size == 0:
-        gradient_accumulation_size = 1
-
-    if batch_size / gradient_accumulation_size < 2:
-        gradient_accumulation_size = int(batch_size / 2)
-        if gradient_accumulation_size == 0:
-            gradient_accumulation_size = 1
-        messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
-    elif batch_size % gradient_accumulation_size != 0:
-        gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
-        if gradient_accumulation_size == 0:
-            gradient_accumulation_size = 1
-        messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
-
-    iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
-
-    if epochs < print_rate:
-        print_rate = epochs
-        messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
-
-    if epochs < save_rate:
-        save_rate = epochs
-        messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
-
-    if resume_path and not os.path.exists(resume_path):
-        resume_path = None
-        messages.append("Resume path specified, but does not exist. Disabling...")
-
-    if bnb:
-        messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
-
-    if half_p:
-        if bnb:
-            half_p = False
-            messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
-        else:
-            messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
-            if not os.path.exists(get_halfp_model_path()):
-                convert_to_halfp()
-
-    messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
-
-    return (
-        learning_rate,
-        text_ce_lr_weight,
-        learning_rate_schedule,
-        batch_size,
-        gradient_accumulation_size,
-        print_rate,
-        save_rate,
-        resume_path,
-        messages
-    )
-
-def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None):
-    if not source_model:
-        source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
-
-    settings = {
-        "iterations": iterations if iterations else 500,
-        "batch_size": batch_size if batch_size else 64,
-        "learning_rate": learning_rate if learning_rate else 1e-5,
-        "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE,
-        "gradient_accumulation_size": gradient_accumulation_size if gradient_accumulation_size else 4,
-        "print_rate": print_rate if print_rate else 1,
-        "save_rate": save_rate if save_rate else 50,
-        "name": name if name else "finetune",
-        "dataset_name": dataset_name if dataset_name else "finetune",
-        "dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
-        "validation_name": validation_name if validation_name else "finetune",
-        "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
-        "text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
-        'resume_state': f"resume_state: '{resume_path}'",
-        'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",
-        'float16': 'true' if half_p else 'false',
-        'bitsandbytes': 'true' if bnb else 'false',
-        'workers': workers if workers else 2,
-    }
-
-    if resume_path:
-        settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
-    else:
-        settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'"
-
-    if half_p:
-        if not os.path.exists(get_halfp_model_path()):
-            convert_to_halfp()
-
-    if not output_name:
-        output_name = f'{settings["name"]}.yaml'
-
-    with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
-        yaml = f.read()
-
-    # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
-    for k in settings:
-        if settings[k] is None:
-            continue
-        yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
-
-    outfile = f'./training/{output_name}'
-    with open(outfile, 'w', encoding="utf-8") as f:
-        f.write(yaml)
-
-    return f"Training settings saved to: {outfile}"
+    if whisper_model is None:
+        load_whisper_model()
+
+    os.makedirs(outdir, exist_ok=True)
+
+    idx = 0
+    results = {}
+    transcription = []
+
+    idx = 0
+
+    if (torch.cuda.is_available()):
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+        print(f"Transcribing file: {file}")
+        result = whisper_model.transcribe(file)
+
+        print(result["segments"]) # before alignment
+
+        # load alignment model and metadata
+        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+        # align whisper output
+        result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
+
+        print(result_aligned["segments"]) # after alignment
+        print(result_aligned["word_segments"]) # after alignment
+        results[os.path.basename(file)] = result
+
+        print(f"Transcribed file: {file}, {len(result['segments'])} found.")
+
+        waveform, sampling_rate = torchaudio.load(file)
+        num_channels, num_frames = waveform.shape
+        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+            start = int(segment['start'] * sampling_rate)
+            end = int(segment['end'] * sampling_rate)
+
+            sliced_waveform = waveform[:, start:end]
+            sliced_name = f"{pad(idx, 4)}.wav"
+
+            torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
+
+            idx = idx + 1
+            line = f"{sliced_name}|{segment['text'].strip()}"
+            transcription.append(line)
+            with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+                f.write(f'{line}\n')
+
+    '''for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+        basename = os.path.basename(file)
+        result = whisper_transcribe(file, language=language)
+        results[basename] = result
+        print(f"Transcribed file: {file}, {len(result['segments'])} found.")
+
+        waveform, sampling_rate = torchaudio.load(file)
+        num_channels, num_frames = waveform.shape
+
+        idx = 0
+        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+            start = int(segment['start'] * sampling_rate)
+            end = int(segment['end'] * sampling_rate)
+
+            sliced_waveform = waveform[:, start:end]
+            sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
+
+            if not torch.any(sliced_waveform < 0):
+                print(f"Error with {sliced_name}, skipping...")
+                continue
+
+            torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
+
+            idx = idx + 1
+            line = f"{sliced_name}|{segment['text'].strip()}"
+            transcription.append(line)
+            with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+                f.write(f'{line}\n')
+    '''
+    with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
+        f.write(json.dumps(results, indent='\t'))
+
+    joined = '\n'.join(transcription)
+    with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
+        f.write(joined)
+
+    unload_whisper()
+
+    return f"Processed dataset to: {outdir}\n{joined}"

 def calc_iterations(epochs, lines, batch_size):
     iterations = int(epochs * lines / float(batch_size))
     return iterations
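
After this rewrite, one prepare_dataset run leaves a directory shaped roughly as below (outdir is caller-supplied; the clip index is the zero-padded counter above, which runs across all input files rather than resetting per file, the per-file naming scheme surviving only in the commented-out fallback, and which is what lets train.txt be rebuilt wholesale from the transcription list at the end):

    <outdir>/
        0000.wav      # one clip per WhisperX segment
        0001.wav
        ...
        train.txt     # one "<clip>|<transcript>" line per clip, e.g. "0000.wav|Hello there."
        whisper.json  # raw transcription results keyed by source filename

For the surviving calc_iterations, the arithmetic is int(epochs * lines / batch_size): for example, 100 epochs over a 250-line train.txt at batch size 25 gives int(100 * 250 / 25.0) = 1000 steps, i.e. 10 steps per epoch.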
@@ -2007,7 +1882,7 @@ def unload_voicefixer():
         do_gc()

-def load_whisper_model(language=None, model_name=None, progress=None):
+def load_whisper_model(model_name=None, progress=None):
     global whisper_model

     if not model_name:
@@ -2016,24 +1891,16 @@ def load_whisper_model(language=None, model_name=None, progress=None):
         args.whisper_model = model_name
         save_args_settings()

-    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
-        model_name = f'{model_name}.{language}'
-        print(f"Loading specialized model for language: {language}")
-
-    notify_progress(f"Loading Whisper model: {model_name}", progress)
-    if args.whisper_cpp:
-        from whispercpp import Whisper
-        if not language:
-            language = 'auto'
-        b_lang = language.encode('ascii')
-        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
+    if (torch.cuda.is_available()):
+        device = "cuda"
     else:
-        import whisper
-        whisper_model = whisper.load_model(model_name)
+        device = "cpu"

-    print("Loaded Whisper model")
+    notify_progress(f"Loading WhisperX model: {model_name} using {device}", progress)
+    whisper_model = whisperx.load_model(model_name, device)
+    print("Loaded WhisperX model")
@@ -2042,6 +1909,6 @@ def unload_whisper():
     if whisper_model:
         del whisper_model
        whisper_model = None

-        print("Unloaded Whisper")
+        print("Unloaded WhisperX")
    do_gc()
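
Taken together, the new call path is load-on-demand: prepare_dataset calls load_whisper_model when no model is resident, load_whisper_model falls back to whatever args.whisper_model names (the module-level whisperx import added at the top makes whisperx visible there), and unload_whisper releases it when the dataset is written. Note that prepare_dataset's language parameter is now read only inside the commented-out fallback; WhisperX detects the language per file. A hypothetical invocation (voice name and paths illustrative):

    files = ["./voices/mel/1.wav", "./voices/mel/2.wav"]
    print(prepare_dataset(files, outdir="./training/mel", language="en"))
    # -> "Processed dataset to: ./training/mel" followed by the transcript lines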