rely on the whisper.json for handling a lot more things

This commit is contained in:
mrq 2023-03-11 21:17:11 +00:00
parent 9b376c381f
commit 382a3e4104
2 changed files with 94 additions and 126 deletions

View File

@ -33,7 +33,7 @@ from datetime import datetime
from datetime import timedelta from datetime import timedelta
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
from tortoise.utils.text import split_and_recombine_text from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
@ -1059,6 +1059,47 @@ def validate_waveform( waveform, sample_rate ):
return False return False
return True return True
def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
unload_tts()
global whisper_model
if whisper_model is None:
load_whisper_model(language=language)
results = {}
files = sorted( get_voices(load_latents=False)[voice] )
indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json'
os.makedirs(f'{indir}/audio/', exist_ok=True)
if os.path.exists(infile):
results = json.load(open(infile, 'r', encoding="utf-8"))
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
basename = os.path.basename(file)
if basename in results and skip_existings:
print(f"Skipping already parsed file: {basename}")
continue
results[basename] = whisper_transcribe(file, language=language)
# lazy copy
waveform, sampling_rate = torchaudio.load(file)
torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
do_gc()
unload_whisper()
return f"Processed dataset to: {indir}"
def slice_dataset( voice, start_offset=0, end_offset=0 ): def slice_dataset( voice, start_offset=0, end_offset=0 ):
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json' infile = f'{indir}/whisper.json'
@ -1066,148 +1107,71 @@ def slice_dataset( voice, start_offset=0, end_offset=0 ):
if not os.path.exists(infile): if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}") raise Exception(f"Missing dataset: {infile}")
with open(infile, 'r', encoding="utf-8") as f: results = json.load(open(infile, 'r', encoding="utf-8"))
results = json.load(f)
transcription = [] files = 0
segments = 0
for filename in results: for filename in results:
idx = 0 files += 1
result = results[filename] result = results[filename]
waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}') waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress): for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
segments +=1
start = int((segment['start'] + start_offset) * sampling_rate) start = int((segment['start'] + start_offset) * sampling_rate)
end = int((segment['end'] + end_offset) * sampling_rate) end = int((segment['end'] + end_offset) * sampling_rate)
sliced_waveform = waveform[:, start:end] sliced = waveform[:, start:end]
sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav") file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
if not validate_waveform( sliced_waveform, sampling_rate ): if not validate_waveform( sliced, sampling_rate ):
print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...") print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
continue continue
torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate) torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
idx = idx + 1 return f"Sliced segments: {files} => {segments}."
line = f"audio/{sliced_name}|{segment['text'].strip()}"
transcription.append(line)
with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}')
joined = "\n".join(transcription) def prepare_dataset( voice, use_segments, text_length, audio_length ):
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
f.write(joined)
return f"Processed dataset to: {indir}\n{joined}"
def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
unload_tts()
global whisper_model
if whisper_model is None:
load_whisper_model(language=language)
os.makedirs(f'{outdir}/audio/', exist_ok=True)
results = {}
transcription = []
files = sorted(files)
previous_list = []
if skip_existings and os.path.exists(f'{outdir}/train.txt'):
parsed_list = []
with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
parsed_list = f.readlines()
for line in parsed_list:
match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
if match is None or len(match) == 0:
continue
if match[0] not in previous_list:
previous_list.append(f'{match[0].split("/")[-1]}.wav')
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
basename = os.path.basename(file)
if basename in previous_list:
print(f"Skipping already parsed file: {basename}")
continue
result = whisper_transcribe(file, language=language)
results[basename] = result
print(f"Transcribed file: {file}, {len(result['segments'])} found.")
waveform, sampling_rate = torchaudio.load(file)
if not validate_waveform( waveform, sampling_rate ):
print(f"Invalid waveform: {basename}, skipping...")
continue
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
line = f"audio/{basename}|{result['text'].strip()}"
transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}')
do_gc()
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
unload_whisper()
joined = "\n".join(transcription)
if not skip_existings:
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
f.write(joined)
return f"Processed dataset to: {outdir}\n{joined}"
def prepare_validation_dataset( voice, text_length, audio_length ):
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/dataset.txt' infile = f'{indir}/whisper.json'
if not os.path.exists(infile):
infile = f'{indir}/train.txt'
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
dst.write(src.read())
if not os.path.exists(infile): if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}") raise Exception(f"Missing dataset: {infile}")
with open(infile, 'r', encoding="utf-8") as f: results = json.load(open(infile, 'r', encoding="utf-8"))
lines = f.readlines()
validation = [] lines = {
training = [] 'training': [],
'validation': [],
}
for filename in results:
result = results[filename]
segments = result['segments'] if use_segments else [{'text': result['text']}]
for segment in segments:
text = segment['text'].strip()
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
for line in lines:
split = line.split("|")
filename = split[0]
text = split[1]
culled = len(text) < text_length culled = len(text) < text_length
if not culled and audio_length > 0: if not culled and audio_length > 0:
metadata = torchaudio.info(f'{indir}/{filename}') metadata = torchaudio.info(f'{indir}/audio/{file}')
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
culled = duration < audio_length culled = duration < audio_length
if culled: lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
validation.append(line.strip())
else: training_joined = "\n".join(lines['training'])
training.append(line.strip()) validation_joined = "\n".join(lines['validation'])
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f: with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(training)) f.write(training_joined)
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(validation)) f.write(validation_joined)
msg = f"Culled {len(validation)}/{len(lines)} lines." msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
print(msg)
return msg return msg
def calc_iterations( epochs, lines, batch_size ): def calc_iterations( epochs, lines, batch_size ):

View File

@ -182,16 +182,19 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
gr.update(visible=j is not None), gr.update(visible=j is not None),
) )
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=False) ):
messages = [] messages = []
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
messages.append(message) messages.append(message)
if slice_audio: if slice_audio:
message = slice_dataset( voice ) message = slice_dataset( voice )
messages.append(message) messages.append(message)
if validation_text_length > 0 or validation_audio_length > 0:
message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length ) message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length )
messages.append(message) messages.append(message)
return "\n".join(messages) return "\n".join(messages)
def update_args_proxy( *args ): def update_args_proxy( *args ):
@ -421,8 +424,8 @@ def setup_gradio():
with gr.Row(): with gr.Row():
transcribe_button = gr.Button(value="Transcribe") transcribe_button = gr.Button(value="Transcribe")
prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
slice_dataset_button = gr.Button(value="(Re)Slice Audio") slice_dataset_button = gr.Button(value="(Re)Slice Audio")
prepare_dataset_button = gr.Button(value="(Re)Create Dataset")
with gr.Row(): with gr.Row():
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
@ -654,7 +657,7 @@ def setup_gradio():
inputs=None, inputs=None,
outputs=[ outputs=[
GENERATE_SETTINGS['voice'], GENERATE_SETTINGS['voice'],
dataset_settings[0], DATASET_SETTINGS['voice'],
history_voices history_voices
] ]
) )
@ -742,10 +745,11 @@ def setup_gradio():
inputs=dataset_settings, inputs=dataset_settings,
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )
prepare_validation_button.click( prepare_dataset_button.click(
prepare_validation_dataset, prepare_dataset,
inputs=[ inputs=[
dataset_settings[0], DATASET_SETTINGS['voice'],
DATASET_SETTINGS['slice'],
DATASET_SETTINGS['validation_text_length'], DATASET_SETTINGS['validation_text_length'],
DATASET_SETTINGS['validation_audio_length'], DATASET_SETTINGS['validation_audio_length'],
], ],
@ -754,7 +758,7 @@ def setup_gradio():
slice_dataset_button.click( slice_dataset_button.click(
slice_dataset, slice_dataset,
inputs=[ inputs=[
dataset_settings[0] DATASET_SETTINGS['voice']
], ],
outputs=prepare_dataset_output outputs=prepare_dataset_output
) )