@@ -33,7 +33,7 @@ from datetime import datetime
 from datetime import timedelta
 
 from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
-from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
+from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
 
@@ -1059,155 +1059,119 @@ def validate_waveform( waveform, sample_rate ):
         return False
 
     return True
 
-def slice_dataset( voice, start_offset=0, end_offset=0 ):
-    indir = f'./training/{voice}/'
-    infile = f'{indir}/whisper.json'
-    if not os.path.exists(infile):
-        raise Exception(f"Missing dataset: {infile}")
-
-    with open(infile, 'r', encoding="utf-8") as f:
-        results = json.load(f)
-
-    transcription = []
-    for filename in results:
-        idx = 0
-        result = results[filename]
-        waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
-
-        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-            start = int((segment['start'] + start_offset) * sampling_rate)
-            end = int((segment['end'] + end_offset) * sampling_rate)
-
-            sliced_waveform = waveform[:, start:end]
-            sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
-
-            if not validate_waveform( sliced_waveform, sampling_rate ):
-                print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
-                continue
-
-            torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
-
-            idx = idx + 1
-            line = f"audio/{sliced_name}|{segment['text'].strip()}"
-            transcription.append(line)
-            with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
-                f.write(f'\n{line}')
-
-    joined = "\n".join(transcription)
-    with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-        f.write(joined)
-
-    return f"Processed dataset to: {indir}\n{joined}"
-
-def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
+def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
     unload_tts()
 
     global whisper_model
     if whisper_model is None:
         load_whisper_model(language=language)
 
-    os.makedirs(f'{outdir}/audio/', exist_ok=True)
-
     results = {}
-    transcription = []
-    files = sorted(files)
-
-    previous_list = []
-    if skip_existings and os.path.exists(f'{outdir}/train.txt'):
-        parsed_list = []
-        with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
-            parsed_list = f.readlines()
-
-        for line in parsed_list:
-            match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
-
-            if match is None or len(match) == 0:
-                continue
-
-            if match[0] not in previous_list:
-                previous_list.append(f'{match[0].split("/")[-1]}.wav')
+
+    files = sorted( get_voices(load_latents=False)[voice] )
+    indir = f'./training/{voice}/'
+    infile = f'{indir}/whisper.json'
+    os.makedirs(f'{indir}/audio/', exist_ok=True)
+
+    if os.path.exists(infile):
+        results = json.load(open(infile, 'r', encoding="utf-8"))
 
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
         basename = os.path.basename(file)
 
-        if basename in previous_list:
+        if basename in results and skip_existings:
             print(f"Skipping already parsed file: {basename}")
             continue
 
-        result = whisper_transcribe(file, language=language)
-        results[basename] = result
-        print(f"Transcribed file: {file}, {len(result['segments'])} found.")
+        results[basename] = whisper_transcribe(file, language=language)
 
+        # lazy copy
         waveform, sampling_rate = torchaudio.load(file)
+        torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
 
-        if not validate_waveform( waveform, sampling_rate ):
-            print(f"Invalid waveform: {basename}, skipping...")
-            continue
-
-        torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
-        line = f"audio/{basename}|{result['text'].strip()}"
-        transcription.append(line)
-        with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-            f.write(f'\n{line}')
-
+    with open(infile, 'w', encoding="utf-8") as f:
+        f.write(json.dumps(results, indent='\t'))
+
+    do_gc()
 
-    with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
-        f.write(json.dumps(results, indent='\t'))
-
     unload_whisper()
 
-    joined = "\n".join(transcription)
-    if not skip_existings:
-        with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-            f.write(joined)
-
-    return f"Processed dataset to: {outdir}\n{joined}"
+    return f"Processed dataset to: {indir}"
 
-def prepare_validation_dataset( voice, text_length, audio_length ):
+def slice_dataset( voice, start_offset=0, end_offset=0 ):
     indir = f'./training/{voice}/'
-    infile = f'{indir}/dataset.txt'
+    infile = f'{indir}/whisper.json'
     if not os.path.exists(infile):
-        infile = f'{indir}/train.txt'
-        with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
-            with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
-                dst.write(src.read())
+        raise Exception(f"Missing dataset: {infile}")
+
+    results = json.load(open(infile, 'r', encoding="utf-8"))
+
+    files = 0
+    segments = 0
+    for filename in results:
+        files += 1
+
+        result = results[filename]
+        waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
+
+        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+            segments += 1
+            start = int((segment['start'] + start_offset) * sampling_rate)
+            end = int((segment['end'] + end_offset) * sampling_rate)
+
+            sliced = waveform[:, start:end]
+            file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
+
+            if not validate_waveform( sliced, sampling_rate ):
+                print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
+                continue
+
+            torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
+
+    return f"Sliced segments: {files} => {segments}."
+
+def prepare_dataset( voice, use_segments, text_length, audio_length ):
+    indir = f'./training/{voice}/'
+    infile = f'{indir}/whisper.json'
+    if not os.path.exists(infile):
+        raise Exception(f"Missing dataset: {infile}")
 
-    with open(infile, 'r', encoding="utf-8") as f:
-        lines = f.readlines()
+    results = json.load(open(infile, 'r', encoding="utf-8"))
+
+    lines = {
+        'training': [],
+        'validation': [],
+    }
 
-    validation = []
-    training = []
+    for filename in results:
+        result = results[filename]
+        segments = result['segments'] if use_segments else [{'text': result['text']}]
+        for segment in segments:
+            text = segment['text'].strip()
+            file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
 
-    for line in lines:
-        split = line.split("|")
-        filename = split[0]
-        text = split[1]
-
-        culled = len(text) < text_length
+            culled = len(text) < text_length
+
+            if not culled and audio_length > 0:
+                metadata = torchaudio.info(f'{indir}/audio/{file}')
+                duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+                culled = duration < audio_length
 
-        if not culled and audio_length > 0:
-            metadata = torchaudio.info(f'{indir}/{filename}')
-            duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
-            culled = duration < audio_length
-
-        if culled:
-            validation.append(line.strip())
-        else:
-            training.append(line.strip())
+            lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
+
+    training_joined = "\n".join(lines['training'])
+    validation_joined = "\n".join(lines['validation'])
 
     with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-        f.write("\n".join(training))
+        f.write(training_joined)
 
     with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
-        f.write("\n".join(validation))
+        f.write(validation_joined)
 
-    msg = f"Culled {len(validation)}/{len(lines)} lines."
-    print(msg)
+    msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
     return msg
 
 def calc_iterations( epochs, lines, batch_size ):
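
Note on the hunk above: transcription no longer writes train.txt directly. transcribe_dataset() only emits whisper.json (plus a lazy copy of each source wav into ./training/<voice>/audio/), slice_dataset() cuts per-segment clips from those timestamps, and prepare_dataset() now performs the train/validation split that prepare_validation_dataset() used to handle. Below is a minimal driver sketch of the new flow, assuming this file is importable as `utils` and that a voice folder with source wavs exists under ./voices/; the voice name and parameter values are illustrative only, not part of this patch.

# Hypothetical driver for the reworked dataset pipeline (not part of this patch).
import utils

voice = "example"  # assumed folder under ./voices/ containing source wavs

# 1) Whisper-transcribe every source file into ./training/example/whisper.json,
#    lazily copying each wav into ./training/example/audio/.
print(utils.transcribe_dataset(voice, language="en", skip_existings=True))

# 2) Slice each wav into per-segment clips from the whisper.json timestamps;
#    the offsets (in seconds) widen or narrow each segment.
print(utils.slice_dataset(voice, start_offset=0, end_offset=0))

# 3) Write train.txt / validation.txt, culling any line whose text is shorter
#    than text_length characters or whose audio fails the audio_length check.
print(utils.prepare_dataset(voice, use_segments=True, text_length=12, audio_length=1))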