forked from camenduru/ai-voice-cloning
move validating audio to creating the text files instead, consider audio longer than 11 seconds invalid, consider text lengths over 200 invalid
This commit is contained in:
parent 51ddc205cd
commit 239c984850

src/utils.py (87 changed lines):
@@ -1065,11 +1065,18 @@ def whisper_transcribe( file, language=None ):
 
 def validate_waveform( waveform, sample_rate ):
 	if not torch.any(waveform < 0):
-		return False
+		return "Waveform is empty"
 
-	if waveform.shape[-1] < (.6 * sample_rate):
-		return False
-	return True
+	num_channels, num_frames = waveform.shape
+	duration = num_channels * num_frames / sample_rate
+
+	if duration < 0.6:
+		return "Duration too short ({:.3f} < 0.6s)".format(duration)
+
+	if duration > 11:
+		return "Duration too long (11s < {:.3f})".format(duration)
+
+	return
 
 def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
 	unload_tts()
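For reference, the rewritten validator reports *why* a clip was rejected rather than returning a bare boolean: any returned string is an error, `None` is a pass. A minimal, self-contained sketch of that contract (the sample clip and rate below are illustrative):

	import torch

	def validate_waveform( waveform, sample_rate ):
		# no negative samples is treated as an empty/silent waveform
		if not torch.any(waveform < 0):
			return "Waveform is empty"

		# duration as computed in the commit (scales with channel count)
		num_channels, num_frames = waveform.shape
		duration = num_channels * num_frames / sample_rate

		if duration < 0.6:
			return "Duration too short ({:.3f} < 0.6s)".format(duration)

		if duration > 11:
			return "Duration too long (11s < {:.3f})".format(duration)

		return

	# A 30-second mono clip at 22.05 kHz trips the new upper bound.
	clip = torch.randn(1, 30 * 22050)
	error = validate_waveform( clip, 22050 )
	if error:
		print(f"skipping: {error}")  # -> "Duration too long (11s < 30.000)"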
@@ -1100,8 +1107,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
 		results[basename] = whisper_transcribe(file, language=language)
 
 		# lazy copy
-		waveform, sampling_rate = torchaudio.load(file)
-		torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
+		waveform, sample_rate = torchaudio.load(file)
+		torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
 
 	with open(infile, 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))
@@ -1115,6 +1122,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
 def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
+	messages = []
 
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
@@ -1124,15 +1132,21 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 	files = 0
 	segments = 0
 	for filename in results:
+		path = f'./voices/{voice}/{filename}'
+		if not os.path.exists(path):
+			path = f'./training/{voice}/{filename}'
+
+		if not os.path.exists(path):
+			messages.append(f"Missing source audio: {filename}")
+			continue
+
 		files += 1
 
 		result = results[filename]
-		waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
+		waveform, sample_rate = torchaudio.load(path)
 
-		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-			segments +=1
-			start = int((segment['start'] + start_offset) * sampling_rate)
-			end = int((segment['end'] + end_offset) * sampling_rate)
+		for segment in result['segments']:
+			start = int((segment['start'] + start_offset) * sample_rate)
+			end = int((segment['end'] + end_offset) * sample_rate)
 
 			if start < 0:
 				start = 0
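The slicing math above converts Whisper's per-segment timestamps (seconds) into sample indices. A quick illustration with a hypothetical segment at 22.05 kHz:

	sample_rate = 22050
	start_offset, end_offset = 0, 0
	segment = {'start': 1.25, 'end': 3.5}  # illustrative Whisper segment

	start = int((segment['start'] + start_offset) * sample_rate)  # 27562
	end = int((segment['end'] + end_offset) * sample_rate)        # 77175
	# waveform[:, start:end] then yields the 2.25-second slice.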
@@ -1142,20 +1156,19 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 			sliced = waveform[:, start:end]
 			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
 
-			if not validate_waveform( sliced, sampling_rate ):
-				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
-				continue
-
 			if trim_silence:
-				sliced = torchaudio.functional.vad( sliced, sampling_rate )
+				sliced = torchaudio.functional.vad( sliced, sample_rate )
 
-			torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
+			segments +=1
+			torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
 
-	return f"Sliced segments: {files} => {segments}."
+	messages.append(f"Sliced segments: {files} => {segments}.")
+	return "\n".join(messages)
 
 def prepare_dataset( voice, use_segments, text_length, audio_length ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
+	messages = []
 
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
@@ -1171,16 +1184,27 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ):
 		result = results[filename]
 		segments = result['segments'] if use_segments else [{'text': result['text']}]
 		for segment in segments:
-			text = segment['text'].strip()
 			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
 			path = f'{indir}/audio/{file}'
 			if not os.path.exists(path):
+				messages.append(f"Missing source audio: {file}")
+				continue
+
+			text = segment['text'].strip()
+			if len(text) > 200:
+				messages.append(f"[{file}] Text length too long (200 < {len(text)}), skipping...")
+
+			waveform, sample_rate = torchaudio.load(path)
+			num_channels, num_frames = waveform.shape
+			duration = num_channels * num_frames / sample_rate
+
+			error = validate_waveform( waveform, sample_rate )
+			if error:
+				messages.append(f"[{file}]: {error}, skipping...")
 				continue
 
 			culled = len(text) < text_length
 			if not culled and audio_length > 0:
-				metadata = torchaudio.info(path)
-				duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
 				culled = duration < audio_length
 
 			lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
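To make the culling rule above concrete: a segment whose text is shorter than text_length (or whose audio is shorter than audio_length, when that is positive) is routed to validation rather than training. A small sketch with illustrative values:

	lines = {'training': [], 'validation': []}
	text_length, audio_length = 12, 0.6

	# Illustrative segment: 2.4 s of audio transcribed as "Hello there."
	file, text, duration = 'sample_0001.wav', 'Hello there.', 2.4

	culled = len(text) < text_length  # 12 < 12 -> False, so it trains
	if not culled and audio_length > 0:
		culled = duration < audio_length

	lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')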
@@ -1194,8 +1218,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ):
 	with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
 		f.write(validation_joined)
 
-	msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
-	return msg
+	messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}")
+	return "\n".join(messages)
 
 def calc_iterations( epochs, lines, batch_size ):
 	iterations = int(epochs * lines / float(batch_size))
@@ -1213,6 +1237,9 @@ def optimize_training_settings( **kwargs ):
 	with open(dataset_path, 'r', encoding="utf-8") as f:
 		lines = len(f.readlines())
 
+	if lines == 0:
+		raise Exception("Empty dataset.")
+
 	if settings['batch_size'] > lines:
 		settings['batch_size'] = lines
 		messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}")
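The new guard makes an empty dataset fail loudly before the batch-size clamp runs. A sketch with illustrative numbers:

	lines = 5                       # line count read from the dataset file
	settings = {'batch_size': 128}

	if lines == 0:
		raise Exception("Empty dataset.")

	if settings['batch_size'] > lines:
		settings['batch_size'] = lines  # clamped to 5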
@@ -1471,17 +1498,17 @@ def import_voices(files, saveAs=None, progress=None):
 		path = f"{outdir}/{os.path.basename(filename)}"
 		print(f"Importing voice to {path}")
 
-		waveform, sampling_rate = torchaudio.load(filename)
+		waveform, sample_rate = torchaudio.load(filename)
 
 		if args.voice_fixer:
 			if not voicefixer:
 				load_voicefixer()
 
 			# resample to best bandwidth since voicefixer will do it anyways through librosa
-			if sampling_rate != 44100:
+			if sample_rate != 44100:
 				print(f"Resampling imported voice sample: {path}")
 				resampler = torchaudio.transforms.Resample(
-					sampling_rate,
+					sample_rate,
 					44100,
 					lowpass_filter_width=16,
 					rolloff=0.85,
@@ -1489,9 +1516,9 @@ def import_voices(files, saveAs=None, progress=None):
 					beta=8.555504641634386,
 				)
 				waveform = resampler(waveform)
-				sampling_rate = 44100
+				sample_rate = 44100
 
-			torchaudio.save(path, waveform, sampling_rate)
+			torchaudio.save(path, waveform, sample_rate)
 
 			print(f"Running 'voicefixer' on voice sample: {path}")
 			voicefixer.restore(
@@ -1501,7 +1528,7 @@ def import_voices(files, saveAs=None, progress=None):
 				#mode=mode,
 			)
 		else:
-			torchaudio.save(path, waveform, sampling_rate)
+			torchaudio.save(path, waveform, sample_rate)
 
 		print(f"Imported voice to {path}")
 