forked from mrq/ai-voice-cloning
added option to not slice audio when transcribing, added option to prepare validation dataset on audio duration, added a warning if youre using whisperx and you're slicing audio
This commit is contained in:
parent
dcdcf8516c
commit
2424c455cb
74
src/utils.py
74
src/utils.py
|
@ -667,7 +667,7 @@ class TrainingState():
|
|||
self.steps = int(self.info['steps'])
|
||||
|
||||
if 'iteration_rate' in self.info:
|
||||
it_rate = self.info['iteration_rate']
|
||||
it_rate = self.info['iteration_rate'] / self.batch_size # why
|
||||
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
||||
self.it_rates += it_rate
|
||||
|
||||
|
@ -676,6 +676,7 @@ class TrainingState():
|
|||
eta = str(timedelta(seconds=int(self.eta)))
|
||||
self.eta_hhmmss = eta
|
||||
except Exception as e:
|
||||
self.eta_hhmmss = "?"
|
||||
pass
|
||||
|
||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||
|
@ -1064,13 +1065,16 @@ def whisper_transcribe( file, language=None ):
|
|||
|
||||
return result
|
||||
|
||||
def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
|
||||
def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ):
|
||||
unload_tts()
|
||||
|
||||
global whisper_model
|
||||
if whisper_model is None:
|
||||
load_whisper_model(language=language)
|
||||
|
||||
if args.whisper_backend == "m-bain/whisperx" and slice_audio:
|
||||
print("! CAUTION ! Slicing audio with whisperx is terrible. Please consider using a different whisper backend if you want to slice audio.")
|
||||
|
||||
os.makedirs(f'{outdir}/audio/', exist_ok=True)
|
||||
|
||||
results = {}
|
||||
|
@ -1092,6 +1096,14 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
if match[0] not in previous_list:
|
||||
previous_list.append(f'{match[0].split("/")[-1]}.wav')
|
||||
|
||||
def validate_waveform( waveform, sample_rate, name ):
|
||||
if not torch.any(waveform < 0):
|
||||
return False
|
||||
|
||||
if waveform.shape[-1] < (.6 * sampling_rate):
|
||||
return False
|
||||
return True
|
||||
|
||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||
basename = os.path.basename(file)
|
||||
|
||||
|
@ -1106,29 +1118,36 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
waveform, sampling_rate = torchaudio.load(file)
|
||||
num_channels, num_frames = waveform.shape
|
||||
|
||||
idx = 0
|
||||
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||
start = int(segment['start'] * sampling_rate)
|
||||
end = int(segment['end'] * sampling_rate)
|
||||
|
||||
sliced_waveform = waveform[:, start:end]
|
||||
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
||||
|
||||
if not torch.any(sliced_waveform < 0):
|
||||
print(f"Sound file is silent: {sliced_name}, skipping...")
|
||||
if not slice_audio:
|
||||
if not validate_waveform( waveform, sampling_rate, name ):
|
||||
print(f"Segment invalid: {name}, skipping...")
|
||||
continue
|
||||
|
||||
if sliced_waveform.shape[-1] < (.6 * sampling_rate):
|
||||
print(f"Sound file is too short: {sliced_name}, skipping...")
|
||||
continue
|
||||
|
||||
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
||||
|
||||
idx = idx + 1
|
||||
line = f"audio/{sliced_name}|{segment['text'].strip()}"
|
||||
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
|
||||
line = f"audio/{basename}|{result['text'].strip()}"
|
||||
transcription.append(line)
|
||||
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
||||
f.write(f'\n{line}')
|
||||
else:
|
||||
idx = 0
|
||||
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||
start = int(segment['start'] * sampling_rate)
|
||||
end = int(segment['end'] * sampling_rate)
|
||||
|
||||
sliced_waveform = waveform[:, start:end]
|
||||
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
||||
|
||||
if not validate_waveform( sliced_waveform, sampling_rate, sliced_name ):
|
||||
print(f"Trimmed segment invalid: {sliced_name}, skipping...")
|
||||
continue
|
||||
|
||||
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
||||
|
||||
idx = idx + 1
|
||||
line = f"audio/{sliced_name}|{segment['text'].strip()}"
|
||||
transcription.append(line)
|
||||
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
||||
f.write(f'\n{line}')
|
||||
|
||||
do_gc()
|
||||
|
||||
|
@ -1144,7 +1163,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
|
||||
return f"Processed dataset to: {outdir}\n{joined}"
|
||||
|
||||
def prepare_validation_dataset( voice, text_length ):
|
||||
def prepare_validation_dataset( voice, text_length, audio_length ):
|
||||
indir = f'./training/{voice}/'
|
||||
infile = f'{indir}/dataset.txt'
|
||||
if not os.path.exists(infile):
|
||||
|
@ -1166,8 +1185,14 @@ def prepare_validation_dataset( voice, text_length ):
|
|||
split = line.split("|")
|
||||
filename = split[0]
|
||||
text = split[1]
|
||||
culled = len(text) < text_length
|
||||
|
||||
if len(text) < text_length:
|
||||
if not culled and audio_length > 0:
|
||||
metadata = torchaudio.info(f'{indir}/{filename}')
|
||||
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
|
||||
culled = duration < audio_length
|
||||
|
||||
if culled:
|
||||
validation.append(line.strip())
|
||||
else:
|
||||
training.append(line.strip())
|
||||
|
@ -1178,7 +1203,7 @@ def prepare_validation_dataset( voice, text_length ):
|
|||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||
f.write("\n".join(validation))
|
||||
|
||||
msg = f"Culled {len(validation)} lines"
|
||||
msg = f"Culled {len(validation)}/{len(lines)} lines."
|
||||
print(msg)
|
||||
return msg
|
||||
|
||||
|
@ -1896,6 +1921,9 @@ def load_tts( restart=False, autoregressive_model=None ):
|
|||
|
||||
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
|
||||
|
||||
if get_device_name() == "cpu":
|
||||
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
|
||||
|
||||
tts_loading = True
|
||||
try:
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
|
||||
|
|
32
src/webui.py
Normal file → Executable file
32
src/webui.py
Normal file → Executable file
|
@ -152,9 +152,7 @@ def import_generate_settings_proxy( file=None ):
|
|||
res = []
|
||||
for k in GENERATE_SETTINGS_ARGS:
|
||||
res.append(settings[k] if k in settings else None)
|
||||
print(GENERATE_SETTINGS_ARGS)
|
||||
print(settings)
|
||||
print(res)
|
||||
|
||||
return tuple(res)
|
||||
|
||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
|
@ -184,12 +182,12 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|||
gr.update(visible=j is not None),
|
||||
)
|
||||
|
||||
def prepare_dataset_proxy( voice, language, validation_size, skip_existings, progress=gr.Progress(track_tqdm=True) ):
|
||||
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
|
||||
messages = []
|
||||
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
|
||||
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, slice_audio=slice_audio, progress=progress )
|
||||
messages.append(message)
|
||||
if validation_size > 0:
|
||||
message = prepare_validation_dataset( voice, text_length=validation_size )
|
||||
if validation_text_length > 0 or validation_audio_length > 0:
|
||||
message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
|
||||
messages.append(message)
|
||||
return "\n".join(messages)
|
||||
|
||||
|
@ -246,8 +244,7 @@ def import_training_settings_proxy( voice ):
|
|||
output[k] = settings[k]
|
||||
|
||||
output = list(output.values())
|
||||
print(list(TRAINING_SETTINGS.keys()))
|
||||
print(output)
|
||||
|
||||
messages.append(f"Imported training settings: {injson}")
|
||||
|
||||
return output[:-1] + ["\n".join(messages)]
|
||||
|
@ -413,13 +410,20 @@ def setup_gradio():
|
|||
DATASET_SETTINGS['voice'] = gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" )
|
||||
with gr.Row():
|
||||
DATASET_SETTINGS['language'] = gr.Textbox(label="Language", value="en")
|
||||
DATASET_SETTINGS['validation_size'] = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0)
|
||||
DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
|
||||
DATASET_SETTINGS['validation_text_length'] = gr.Number(label="Validation Text Length Threshold", value=12, precision=0)
|
||||
DATASET_SETTINGS['validation_audio_length'] = gr.Number(label="Validation Audio Length Threshold", value=1 )
|
||||
with gr.Row():
|
||||
DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False)
|
||||
DATASET_SETTINGS['slice'] = gr.Checkbox(label="Slice Segments", value=False)
|
||||
|
||||
with gr.Row():
|
||||
transcribe_button = gr.Button(value="Transcribe")
|
||||
prepare_validation_button = gr.Button(value="Prepare Validation")
|
||||
|
||||
with gr.Row():
|
||||
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
|
||||
EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
|
||||
|
||||
dataset_settings = list(DATASET_SETTINGS.values())
|
||||
with gr.Column():
|
||||
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
|
@ -533,8 +537,7 @@ def setup_gradio():
|
|||
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
||||
|
||||
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
|
||||
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
|
||||
EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
|
||||
|
||||
|
||||
EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p']
|
||||
EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes']
|
||||
|
@ -739,7 +742,8 @@ def setup_gradio():
|
|||
prepare_validation_dataset,
|
||||
inputs=[
|
||||
dataset_settings[0],
|
||||
DATASET_SETTINGS['validation_size'],
|
||||
DATASET_SETTINGS['validation_text_length'],
|
||||
DATASET_SETTINGS['validation_audio_length'],
|
||||
],
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user