added option to not slice audio when transcribing, added option to prepare validation dataset on audio duration, added a warning if youre using whisperx and you're slicing audio

This commit is contained in:
mrq 2023-03-11 16:32:35 +00:00
parent dcdcf8516c
commit 2424c455cb
2 changed files with 69 additions and 37 deletions

View File

@ -667,7 +667,7 @@ class TrainingState():
self.steps = int(self.info['steps']) self.steps = int(self.info['steps'])
if 'iteration_rate' in self.info: if 'iteration_rate' in self.info:
it_rate = self.info['iteration_rate'] it_rate = self.info['iteration_rate'] / self.batch_size # why
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
self.it_rates += it_rate self.it_rates += it_rate
@ -676,6 +676,7 @@ class TrainingState():
eta = str(timedelta(seconds=int(self.eta))) eta = str(timedelta(seconds=int(self.eta)))
self.eta_hhmmss = eta self.eta_hhmmss = eta
except Exception as e: except Exception as e:
self.eta_hhmmss = "?"
pass pass
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
@ -1064,13 +1065,16 @@ def whisper_transcribe( file, language=None ):
return result return result
def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ): def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ):
unload_tts() unload_tts()
global whisper_model global whisper_model
if whisper_model is None: if whisper_model is None:
load_whisper_model(language=language) load_whisper_model(language=language)
if args.whisper_backend == "m-bain/whisperx" and slice_audio:
print("! CAUTION ! Slicing audio with whisperx is terrible. Please consider using a different whisper backend if you want to slice audio.")
os.makedirs(f'{outdir}/audio/', exist_ok=True) os.makedirs(f'{outdir}/audio/', exist_ok=True)
results = {} results = {}
@ -1092,6 +1096,14 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
if match[0] not in previous_list: if match[0] not in previous_list:
previous_list.append(f'{match[0].split("/")[-1]}.wav') previous_list.append(f'{match[0].split("/")[-1]}.wav')
def validate_waveform( waveform, sample_rate, name ):
if not torch.any(waveform < 0):
return False
if waveform.shape[-1] < (.6 * sampling_rate):
return False
return True
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
basename = os.path.basename(file) basename = os.path.basename(file)
@ -1106,6 +1118,17 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
waveform, sampling_rate = torchaudio.load(file) waveform, sampling_rate = torchaudio.load(file)
num_channels, num_frames = waveform.shape num_channels, num_frames = waveform.shape
if not slice_audio:
if not validate_waveform( waveform, sampling_rate, name ):
print(f"Segment invalid: {name}, skipping...")
continue
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
line = f"audio/{basename}|{result['text'].strip()}"
transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}')
else:
idx = 0 idx = 0
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress): for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
start = int(segment['start'] * sampling_rate) start = int(segment['start'] * sampling_rate)
@ -1114,12 +1137,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
sliced_waveform = waveform[:, start:end] sliced_waveform = waveform[:, start:end]
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav") sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
if not torch.any(sliced_waveform < 0): if not validate_waveform( sliced_waveform, sampling_rate, sliced_name ):
print(f"Sound file is silent: {sliced_name}, skipping...") print(f"Trimmed segment invalid: {sliced_name}, skipping...")
continue
if sliced_waveform.shape[-1] < (.6 * sampling_rate):
print(f"Sound file is too short: {sliced_name}, skipping...")
continue continue
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate) torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
@ -1144,7 +1163,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
return f"Processed dataset to: {outdir}\n{joined}" return f"Processed dataset to: {outdir}\n{joined}"
def prepare_validation_dataset( voice, text_length ): def prepare_validation_dataset( voice, text_length, audio_length ):
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/dataset.txt' infile = f'{indir}/dataset.txt'
if not os.path.exists(infile): if not os.path.exists(infile):
@ -1166,8 +1185,14 @@ def prepare_validation_dataset( voice, text_length ):
split = line.split("|") split = line.split("|")
filename = split[0] filename = split[0]
text = split[1] text = split[1]
culled = len(text) < text_length
if len(text) < text_length: if not culled and audio_length > 0:
metadata = torchaudio.info(f'{indir}/{filename}')
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
culled = duration < audio_length
if culled:
validation.append(line.strip()) validation.append(line.strip())
else: else:
training.append(line.strip()) training.append(line.strip())
@ -1178,7 +1203,7 @@ def prepare_validation_dataset( voice, text_length ):
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(validation)) f.write("\n".join(validation))
msg = f"Culled {len(validation)} lines" msg = f"Culled {len(validation)}/{len(lines)} lines."
print(msg) print(msg)
return msg return msg
@ -1896,6 +1921,9 @@ def load_tts( restart=False, autoregressive_model=None ):
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})") print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
if get_device_name() == "cpu":
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
tts_loading = True tts_loading = True
try: try:
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model) tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)

30
src/webui.py Normal file → Executable file
View File

@ -152,9 +152,7 @@ def import_generate_settings_proxy( file=None ):
res = [] res = []
for k in GENERATE_SETTINGS_ARGS: for k in GENERATE_SETTINGS_ARGS:
res.append(settings[k] if k in settings else None) res.append(settings[k] if k in settings else None)
print(GENERATE_SETTINGS_ARGS)
print(settings)
print(res)
return tuple(res) return tuple(res)
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
@ -184,12 +182,12 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
gr.update(visible=j is not None), gr.update(visible=j is not None),
) )
def prepare_dataset_proxy( voice, language, validation_size, skip_existings, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
messages = [] messages = []
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress ) message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, slice_audio=slice_audio, progress=progress )
messages.append(message) messages.append(message)
if validation_size > 0: if validation_text_length > 0 or validation_audio_length > 0:
message = prepare_validation_dataset( voice, text_length=validation_size ) message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
messages.append(message) messages.append(message)
return "\n".join(messages) return "\n".join(messages)
@ -246,8 +244,7 @@ def import_training_settings_proxy( voice ):
output[k] = settings[k] output[k] = settings[k]
output = list(output.values()) output = list(output.values())
print(list(TRAINING_SETTINGS.keys()))
print(output)
messages.append(f"Imported training settings: {injson}") messages.append(f"Imported training settings: {injson}")
return output[:-1] + ["\n".join(messages)] return output[:-1] + ["\n".join(messages)]
@ -413,13 +410,20 @@ def setup_gradio():
DATASET_SETTINGS['voice'] = gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ) DATASET_SETTINGS['voice'] = gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" )
with gr.Row(): with gr.Row():
DATASET_SETTINGS['language'] = gr.Textbox(label="Language", value="en") DATASET_SETTINGS['language'] = gr.Textbox(label="Language", value="en")
DATASET_SETTINGS['validation_size'] = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0) DATASET_SETTINGS['validation_text_length'] = gr.Number(label="Validation Text Length Threshold", value=12, precision=0)
DATASET_SETTINGS['validation_audio_length'] = gr.Number(label="Validation Audio Length Threshold", value=1 )
with gr.Row():
DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False) DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
DATASET_SETTINGS['slice'] = gr.Checkbox(label="Slice Segments", value=False)
with gr.Row(): with gr.Row():
transcribe_button = gr.Button(value="Transcribe") transcribe_button = gr.Button(value="Transcribe")
prepare_validation_button = gr.Button(value="Prepare Validation") prepare_validation_button = gr.Button(value="Prepare Validation")
with gr.Row():
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
dataset_settings = list(DATASET_SETTINGS.values()) dataset_settings = list(DATASET_SETTINGS.values())
with gr.Column(): with gr.Column():
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
@ -533,8 +537,7 @@ def setup_gradio():
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1]) EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p'] EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p']
EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes'] EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes']
@ -739,7 +742,8 @@ def setup_gradio():
prepare_validation_dataset, prepare_validation_dataset,
inputs=[ inputs=[
dataset_settings[0], dataset_settings[0],
DATASET_SETTINGS['validation_size'], DATASET_SETTINGS['validation_text_length'],
DATASET_SETTINGS['validation_audio_length'],
], ],
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )