forked from mrq/ai-voice-cloning
added option to not slice audio when transcribing, added option to prepare validation dataset on audio duration, added a warning if youre using whisperx and you're slicing audio
This commit is contained in:
parent
dcdcf8516c
commit
2424c455cb
50
src/utils.py
50
src/utils.py
|
@ -667,7 +667,7 @@ class TrainingState():
|
||||||
self.steps = int(self.info['steps'])
|
self.steps = int(self.info['steps'])
|
||||||
|
|
||||||
if 'iteration_rate' in self.info:
|
if 'iteration_rate' in self.info:
|
||||||
it_rate = self.info['iteration_rate']
|
it_rate = self.info['iteration_rate'] / self.batch_size # why
|
||||||
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
||||||
self.it_rates += it_rate
|
self.it_rates += it_rate
|
||||||
|
|
||||||
|
@ -676,6 +676,7 @@ class TrainingState():
|
||||||
eta = str(timedelta(seconds=int(self.eta)))
|
eta = str(timedelta(seconds=int(self.eta)))
|
||||||
self.eta_hhmmss = eta
|
self.eta_hhmmss = eta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self.eta_hhmmss = "?"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||||
|
@ -1064,13 +1065,16 @@ def whisper_transcribe( file, language=None ):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
|
def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ):
|
||||||
unload_tts()
|
unload_tts()
|
||||||
|
|
||||||
global whisper_model
|
global whisper_model
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
load_whisper_model(language=language)
|
load_whisper_model(language=language)
|
||||||
|
|
||||||
|
if args.whisper_backend == "m-bain/whisperx" and slice_audio:
|
||||||
|
print("! CAUTION ! Slicing audio with whisperx is terrible. Please consider using a different whisper backend if you want to slice audio.")
|
||||||
|
|
||||||
os.makedirs(f'{outdir}/audio/', exist_ok=True)
|
os.makedirs(f'{outdir}/audio/', exist_ok=True)
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
|
@ -1092,6 +1096,14 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
if match[0] not in previous_list:
|
if match[0] not in previous_list:
|
||||||
previous_list.append(f'{match[0].split("/")[-1]}.wav')
|
previous_list.append(f'{match[0].split("/")[-1]}.wav')
|
||||||
|
|
||||||
|
def validate_waveform( waveform, sample_rate, name ):
|
||||||
|
if not torch.any(waveform < 0):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if waveform.shape[-1] < (.6 * sampling_rate):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
basename = os.path.basename(file)
|
basename = os.path.basename(file)
|
||||||
|
|
||||||
|
@ -1106,6 +1118,17 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
waveform, sampling_rate = torchaudio.load(file)
|
waveform, sampling_rate = torchaudio.load(file)
|
||||||
num_channels, num_frames = waveform.shape
|
num_channels, num_frames = waveform.shape
|
||||||
|
|
||||||
|
if not slice_audio:
|
||||||
|
if not validate_waveform( waveform, sampling_rate, name ):
|
||||||
|
print(f"Segment invalid: {name}, skipping...")
|
||||||
|
continue
|
||||||
|
|
||||||
|
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
|
||||||
|
line = f"audio/{basename}|{result['text'].strip()}"
|
||||||
|
transcription.append(line)
|
||||||
|
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
||||||
|
f.write(f'\n{line}')
|
||||||
|
else:
|
||||||
idx = 0
|
idx = 0
|
||||||
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||||
start = int(segment['start'] * sampling_rate)
|
start = int(segment['start'] * sampling_rate)
|
||||||
|
@ -1114,12 +1137,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
sliced_waveform = waveform[:, start:end]
|
sliced_waveform = waveform[:, start:end]
|
||||||
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
||||||
|
|
||||||
if not torch.any(sliced_waveform < 0):
|
if not validate_waveform( sliced_waveform, sampling_rate, sliced_name ):
|
||||||
print(f"Sound file is silent: {sliced_name}, skipping...")
|
print(f"Trimmed segment invalid: {sliced_name}, skipping...")
|
||||||
continue
|
|
||||||
|
|
||||||
if sliced_waveform.shape[-1] < (.6 * sampling_rate):
|
|
||||||
print(f"Sound file is too short: {sliced_name}, skipping...")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
||||||
|
@ -1144,7 +1163,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
|
|
||||||
return f"Processed dataset to: {outdir}\n{joined}"
|
return f"Processed dataset to: {outdir}\n{joined}"
|
||||||
|
|
||||||
def prepare_validation_dataset( voice, text_length ):
|
def prepare_validation_dataset( voice, text_length, audio_length ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/dataset.txt'
|
infile = f'{indir}/dataset.txt'
|
||||||
if not os.path.exists(infile):
|
if not os.path.exists(infile):
|
||||||
|
@ -1166,8 +1185,14 @@ def prepare_validation_dataset( voice, text_length ):
|
||||||
split = line.split("|")
|
split = line.split("|")
|
||||||
filename = split[0]
|
filename = split[0]
|
||||||
text = split[1]
|
text = split[1]
|
||||||
|
culled = len(text) < text_length
|
||||||
|
|
||||||
if len(text) < text_length:
|
if not culled and audio_length > 0:
|
||||||
|
metadata = torchaudio.info(f'{indir}/{filename}')
|
||||||
|
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
|
||||||
|
culled = duration < audio_length
|
||||||
|
|
||||||
|
if culled:
|
||||||
validation.append(line.strip())
|
validation.append(line.strip())
|
||||||
else:
|
else:
|
||||||
training.append(line.strip())
|
training.append(line.strip())
|
||||||
|
@ -1178,7 +1203,7 @@ def prepare_validation_dataset( voice, text_length ):
|
||||||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||||
f.write("\n".join(validation))
|
f.write("\n".join(validation))
|
||||||
|
|
||||||
msg = f"Culled {len(validation)} lines"
|
msg = f"Culled {len(validation)}/{len(lines)} lines."
|
||||||
print(msg)
|
print(msg)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
@ -1896,6 +1921,9 @@ def load_tts( restart=False, autoregressive_model=None ):
|
||||||
|
|
||||||
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
|
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
|
||||||
|
|
||||||
|
if get_device_name() == "cpu":
|
||||||
|
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
|
||||||
|
|
||||||
tts_loading = True
|
tts_loading = True
|
||||||
try:
|
try:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
|
||||||
|
|
30
src/webui.py
Normal file → Executable file
30
src/webui.py
Normal file → Executable file
|
@ -152,9 +152,7 @@ def import_generate_settings_proxy( file=None ):
|
||||||
res = []
|
res = []
|
||||||
for k in GENERATE_SETTINGS_ARGS:
|
for k in GENERATE_SETTINGS_ARGS:
|
||||||
res.append(settings[k] if k in settings else None)
|
res.append(settings[k] if k in settings else None)
|
||||||
print(GENERATE_SETTINGS_ARGS)
|
|
||||||
print(settings)
|
|
||||||
print(res)
|
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
|
@ -184,12 +182,12 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
gr.update(visible=j is not None),
|
gr.update(visible=j is not None),
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_dataset_proxy( voice, language, validation_size, skip_existings, progress=gr.Progress(track_tqdm=True) ):
|
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
|
||||||
messages = []
|
messages = []
|
||||||
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
|
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, slice_audio=slice_audio, progress=progress )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
if validation_size > 0:
|
if validation_text_length > 0 or validation_audio_length > 0:
|
||||||
message = prepare_validation_dataset( voice, text_length=validation_size )
|
message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
@ -246,8 +244,7 @@ def import_training_settings_proxy( voice ):
|
||||||
output[k] = settings[k]
|
output[k] = settings[k]
|
||||||
|
|
||||||
output = list(output.values())
|
output = list(output.values())
|
||||||
print(list(TRAINING_SETTINGS.keys()))
|
|
||||||
print(output)
|
|
||||||
messages.append(f"Imported training settings: {injson}")
|
messages.append(f"Imported training settings: {injson}")
|
||||||
|
|
||||||
return output[:-1] + ["\n".join(messages)]
|
return output[:-1] + ["\n".join(messages)]
|
||||||
|
@ -413,13 +410,20 @@ def setup_gradio():
|
||||||
DATASET_SETTINGS['voice'] = gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" )
|
DATASET_SETTINGS['voice'] = gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" )
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
DATASET_SETTINGS['language'] = gr.Textbox(label="Language", value="en")
|
DATASET_SETTINGS['language'] = gr.Textbox(label="Language", value="en")
|
||||||
DATASET_SETTINGS['validation_size'] = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0)
|
DATASET_SETTINGS['validation_text_length'] = gr.Number(label="Validation Text Length Threshold", value=12, precision=0)
|
||||||
|
DATASET_SETTINGS['validation_audio_length'] = gr.Number(label="Validation Audio Length Threshold", value=1 )
|
||||||
|
with gr.Row():
|
||||||
DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
|
DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
|
||||||
|
DATASET_SETTINGS['slice'] = gr.Checkbox(label="Slice Segments", value=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
transcribe_button = gr.Button(value="Transcribe")
|
transcribe_button = gr.Button(value="Transcribe")
|
||||||
prepare_validation_button = gr.Button(value="Prepare Validation")
|
prepare_validation_button = gr.Button(value="Prepare Validation")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
|
||||||
|
EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
|
||||||
|
|
||||||
dataset_settings = list(DATASET_SETTINGS.values())
|
dataset_settings = list(DATASET_SETTINGS.values())
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
|
@ -533,8 +537,7 @@ def setup_gradio():
|
||||||
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
||||||
|
|
||||||
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
|
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
|
||||||
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
|
|
||||||
EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
|
|
||||||
|
|
||||||
EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p']
|
EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p']
|
||||||
EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes']
|
EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes']
|
||||||
|
@ -739,7 +742,8 @@ def setup_gradio():
|
||||||
prepare_validation_dataset,
|
prepare_validation_dataset,
|
||||||
inputs=[
|
inputs=[
|
||||||
dataset_settings[0],
|
dataset_settings[0],
|
||||||
DATASET_SETTINGS['validation_size'],
|
DATASET_SETTINGS['validation_text_length'],
|
||||||
|
DATASET_SETTINGS['validation_audio_length'],
|
||||||
],
|
],
|
||||||
outputs=prepare_dataset_output #console_output
|
outputs=prepare_dataset_output #console_output
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user