rely on the whisper.json for handling a lot more things
This commit is contained in:
parent
9b376c381f
commit
382a3e4104
194
src/utils.py
194
src/utils.py
|
@ -33,7 +33,7 @@ from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
|
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
|
||||||
|
|
||||||
|
@ -1059,6 +1059,47 @@ def validate_waveform( waveform, sample_rate ):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
|
||||||
|
unload_tts()
|
||||||
|
|
||||||
|
global whisper_model
|
||||||
|
if whisper_model is None:
|
||||||
|
load_whisper_model(language=language)
|
||||||
|
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
files = sorted( get_voices(load_latents=False)[voice] )
|
||||||
|
indir = f'./training/{voice}/'
|
||||||
|
infile = f'{indir}/whisper.json'
|
||||||
|
|
||||||
|
os.makedirs(f'{indir}/audio/', exist_ok=True)
|
||||||
|
|
||||||
|
if os.path.exists(infile):
|
||||||
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
|
|
||||||
|
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
|
basename = os.path.basename(file)
|
||||||
|
|
||||||
|
if basename in results and skip_existings:
|
||||||
|
print(f"Skipping already parsed file: {basename}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
results[basename] = whisper_transcribe(file, language=language)
|
||||||
|
|
||||||
|
# lazy copy
|
||||||
|
waveform, sampling_rate = torchaudio.load(file)
|
||||||
|
torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
|
||||||
|
|
||||||
|
with open(infile, 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(results, indent='\t'))
|
||||||
|
|
||||||
|
do_gc()
|
||||||
|
|
||||||
|
unload_whisper()
|
||||||
|
|
||||||
|
return f"Processed dataset to: {indir}"
|
||||||
|
|
||||||
def slice_dataset( voice, start_offset=0, end_offset=0 ):
|
def slice_dataset( voice, start_offset=0, end_offset=0 ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
|
@ -1066,148 +1107,71 @@ def slice_dataset( voice, start_offset=0, end_offset=0 ):
|
||||||
if not os.path.exists(infile):
|
if not os.path.exists(infile):
|
||||||
raise Exception(f"Missing dataset: {infile}")
|
raise Exception(f"Missing dataset: {infile}")
|
||||||
|
|
||||||
with open(infile, 'r', encoding="utf-8") as f:
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
results = json.load(f)
|
|
||||||
|
|
||||||
transcription = []
|
files = 0
|
||||||
|
segments = 0
|
||||||
for filename in results:
|
for filename in results:
|
||||||
idx = 0
|
files += 1
|
||||||
|
|
||||||
result = results[filename]
|
result = results[filename]
|
||||||
waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
|
waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
|
||||||
|
|
||||||
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||||
|
segments +=1
|
||||||
start = int((segment['start'] + start_offset) * sampling_rate)
|
start = int((segment['start'] + start_offset) * sampling_rate)
|
||||||
end = int((segment['end'] + end_offset) * sampling_rate)
|
end = int((segment['end'] + end_offset) * sampling_rate)
|
||||||
|
|
||||||
sliced_waveform = waveform[:, start:end]
|
sliced = waveform[:, start:end]
|
||||||
sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
||||||
|
|
||||||
if not validate_waveform( sliced_waveform, sampling_rate ):
|
if not validate_waveform( sliced, sampling_rate ):
|
||||||
print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
|
print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
|
||||||
|
|
||||||
idx = idx + 1
|
return f"Sliced segments: {files} => {segments}."
|
||||||
line = f"audio/{sliced_name}|{segment['text'].strip()}"
|
|
||||||
transcription.append(line)
|
|
||||||
with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
|
|
||||||
f.write(f'\n{line}')
|
|
||||||
|
|
||||||
joined = "\n".join(transcription)
|
def prepare_dataset( voice, use_segments, text_length, audio_length ):
|
||||||
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
|
|
||||||
f.write(joined)
|
|
||||||
|
|
||||||
return f"Processed dataset to: {indir}\n{joined}"
|
|
||||||
|
|
||||||
def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
|
|
||||||
unload_tts()
|
|
||||||
|
|
||||||
global whisper_model
|
|
||||||
if whisper_model is None:
|
|
||||||
load_whisper_model(language=language)
|
|
||||||
|
|
||||||
os.makedirs(f'{outdir}/audio/', exist_ok=True)
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
transcription = []
|
|
||||||
files = sorted(files)
|
|
||||||
|
|
||||||
previous_list = []
|
|
||||||
if skip_existings and os.path.exists(f'{outdir}/train.txt'):
|
|
||||||
parsed_list = []
|
|
||||||
with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
|
|
||||||
parsed_list = f.readlines()
|
|
||||||
|
|
||||||
for line in parsed_list:
|
|
||||||
match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
|
|
||||||
|
|
||||||
if match is None or len(match) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if match[0] not in previous_list:
|
|
||||||
previous_list.append(f'{match[0].split("/")[-1]}.wav')
|
|
||||||
|
|
||||||
|
|
||||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
|
||||||
basename = os.path.basename(file)
|
|
||||||
|
|
||||||
if basename in previous_list:
|
|
||||||
print(f"Skipping already parsed file: {basename}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = whisper_transcribe(file, language=language)
|
|
||||||
results[basename] = result
|
|
||||||
print(f"Transcribed file: {file}, {len(result['segments'])} found.")
|
|
||||||
|
|
||||||
waveform, sampling_rate = torchaudio.load(file)
|
|
||||||
|
|
||||||
if not validate_waveform( waveform, sampling_rate ):
|
|
||||||
print(f"Invalid waveform: {basename}, skipping...")
|
|
||||||
continue
|
|
||||||
|
|
||||||
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
|
|
||||||
line = f"audio/{basename}|{result['text'].strip()}"
|
|
||||||
transcription.append(line)
|
|
||||||
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
|
||||||
f.write(f'\n{line}')
|
|
||||||
|
|
||||||
do_gc()
|
|
||||||
|
|
||||||
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(results, indent='\t'))
|
|
||||||
|
|
||||||
unload_whisper()
|
|
||||||
|
|
||||||
joined = "\n".join(transcription)
|
|
||||||
if not skip_existings:
|
|
||||||
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
|
||||||
f.write(joined)
|
|
||||||
|
|
||||||
return f"Processed dataset to: {outdir}\n{joined}"
|
|
||||||
|
|
||||||
def prepare_validation_dataset( voice, text_length, audio_length ):
|
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/dataset.txt'
|
infile = f'{indir}/whisper.json'
|
||||||
if not os.path.exists(infile):
|
|
||||||
infile = f'{indir}/train.txt'
|
|
||||||
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
|
|
||||||
with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
|
|
||||||
dst.write(src.read())
|
|
||||||
|
|
||||||
if not os.path.exists(infile):
|
if not os.path.exists(infile):
|
||||||
raise Exception(f"Missing dataset: {infile}")
|
raise Exception(f"Missing dataset: {infile}")
|
||||||
|
|
||||||
with open(infile, 'r', encoding="utf-8") as f:
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
lines = f.readlines()
|
|
||||||
|
|
||||||
validation = []
|
lines = {
|
||||||
training = []
|
'training': [],
|
||||||
|
'validation': [],
|
||||||
|
}
|
||||||
|
|
||||||
for line in lines:
|
for filename in results:
|
||||||
split = line.split("|")
|
result = results[filename]
|
||||||
filename = split[0]
|
segments = result['segments'] if use_segments else [{'text': result['text']}]
|
||||||
text = split[1]
|
for segment in segments:
|
||||||
culled = len(text) < text_length
|
text = segment['text'].strip()
|
||||||
|
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
|
||||||
|
|
||||||
if not culled and audio_length > 0:
|
culled = len(text) < text_length
|
||||||
metadata = torchaudio.info(f'{indir}/{filename}')
|
if not culled and audio_length > 0:
|
||||||
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
|
metadata = torchaudio.info(f'{indir}/audio/{file}')
|
||||||
culled = duration < audio_length
|
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
|
||||||
|
culled = duration < audio_length
|
||||||
|
|
||||||
if culled:
|
lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
|
||||||
validation.append(line.strip())
|
|
||||||
else:
|
training_joined = "\n".join(lines['training'])
|
||||||
training.append(line.strip())
|
validation_joined = "\n".join(lines['validation'])
|
||||||
|
|
||||||
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
|
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
|
||||||
f.write("\n".join(training))
|
f.write(training_joined)
|
||||||
|
|
||||||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||||
f.write("\n".join(validation))
|
f.write(validation_joined)
|
||||||
|
|
||||||
msg = f"Culled {len(validation)}/{len(lines)} lines."
|
msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
|
||||||
print(msg)
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def calc_iterations( epochs, lines, batch_size ):
|
def calc_iterations( epochs, lines, batch_size ):
|
||||||
|
|
26
src/webui.py
26
src/webui.py
|
@ -182,16 +182,19 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
gr.update(visible=j is not None),
|
gr.update(visible=j is not None),
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
|
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=False) ):
|
||||||
messages = []
|
messages = []
|
||||||
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
|
|
||||||
|
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
if slice_audio:
|
if slice_audio:
|
||||||
message = slice_dataset( voice )
|
message = slice_dataset( voice )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
if validation_text_length > 0 or validation_audio_length > 0:
|
|
||||||
message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
|
message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
def update_args_proxy( *args ):
|
def update_args_proxy( *args ):
|
||||||
|
@ -421,8 +424,8 @@ def setup_gradio():
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
transcribe_button = gr.Button(value="Transcribe")
|
transcribe_button = gr.Button(value="Transcribe")
|
||||||
prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
|
|
||||||
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
|
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
|
||||||
|
prepare_dataset_button = gr.Button(value="(Re)Create Dataset")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
|
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
|
||||||
|
@ -654,7 +657,7 @@ def setup_gradio():
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[
|
outputs=[
|
||||||
GENERATE_SETTINGS['voice'],
|
GENERATE_SETTINGS['voice'],
|
||||||
dataset_settings[0],
|
DATASET_SETTINGS['voice'],
|
||||||
history_voices
|
history_voices
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -742,10 +745,11 @@ def setup_gradio():
|
||||||
inputs=dataset_settings,
|
inputs=dataset_settings,
|
||||||
outputs=prepare_dataset_output #console_output
|
outputs=prepare_dataset_output #console_output
|
||||||
)
|
)
|
||||||
prepare_validation_button.click(
|
prepare_dataset_button.click(
|
||||||
prepare_validation_dataset,
|
prepare_dataset,
|
||||||
inputs=[
|
inputs=[
|
||||||
dataset_settings[0],
|
DATASET_SETTINGS['voice'],
|
||||||
|
DATASET_SETTINGS['slice'],
|
||||||
DATASET_SETTINGS['validation_text_length'],
|
DATASET_SETTINGS['validation_text_length'],
|
||||||
DATASET_SETTINGS['validation_audio_length'],
|
DATASET_SETTINGS['validation_audio_length'],
|
||||||
],
|
],
|
||||||
|
@ -754,7 +758,7 @@ def setup_gradio():
|
||||||
slice_dataset_button.click(
|
slice_dataset_button.click(
|
||||||
slice_dataset,
|
slice_dataset,
|
||||||
inputs=[
|
inputs=[
|
||||||
dataset_settings[0]
|
DATASET_SETTINGS['voice']
|
||||||
],
|
],
|
||||||
outputs=prepare_dataset_output
|
outputs=prepare_dataset_output
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user