forked from mrq/ai-voice-cloning
when creating the train/validatio datasets, use segments if the main audio's duration is too long, and slice to make the segments if they don't exist
This commit is contained in:
parent
0cf9db5e69
commit
ee1b048d07
121
src/utils.py
121
src/utils.py
|
@ -51,6 +51,9 @@ LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||||
|
|
||||||
RESAMPLERS = {}
|
RESAMPLERS = {}
|
||||||
|
|
||||||
|
MIN_TRAINING_DURATION = 0.6
|
||||||
|
MAX_TRAINING_DURATION = 11.6097505669
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
@ -62,6 +65,9 @@ training_state = None
|
||||||
current_voice = None
|
current_voice = None
|
||||||
|
|
||||||
def resample( waveform, input_rate, output_rate=44100 ):
|
def resample( waveform, input_rate, output_rate=44100 ):
|
||||||
|
# mono-ize
|
||||||
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
|
|
||||||
if input_rate == output_rate:
|
if input_rate == output_rate:
|
||||||
return waveform, output_rate
|
return waveform, output_rate
|
||||||
|
|
||||||
|
@ -1066,18 +1072,19 @@ def whisper_transcribe( file, language=None ):
|
||||||
result['segments'].append(reparsed)
|
result['segments'].append(reparsed)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def validate_waveform( waveform, sample_rate ):
|
def validate_waveform( waveform, sample_rate, min_only=False ):
|
||||||
if not torch.any(waveform < 0):
|
if not torch.any(waveform < 0):
|
||||||
return "Waveform is empty"
|
return "Waveform is empty"
|
||||||
|
|
||||||
num_channels, num_frames = waveform.shape
|
num_channels, num_frames = waveform.shape
|
||||||
duration = num_channels * num_frames / sample_rate
|
duration = num_channels * num_frames / sample_rate
|
||||||
|
|
||||||
if duration < 0.6:
|
if duration < MIN_TRAINING_DURATION:
|
||||||
return "Duration too short ({:.3f} < 0.6s)".format(duration)
|
return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, MIN_TRAINING_DURATION)
|
||||||
|
|
||||||
if duration > 11:
|
if not min_only:
|
||||||
return "Duration too long (11s < {:.3f})".format(duration)
|
if duration > MAX_TRAINING_DURATION:
|
||||||
|
return "Duration too long ({:.3f}s < {:.3f}s)".format(MAX_TRAINING_DURATION, duration)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1122,7 +1129,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
||||||
|
|
||||||
return f"Processed dataset to: {indir}"
|
return f"Processed dataset to: {indir}"
|
||||||
|
|
||||||
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
|
def slice_waveform( waveform, sample_rate, start, end, trim ):
|
||||||
|
start = int(start * sample_rate)
|
||||||
|
end = int(end * sample_rate)
|
||||||
|
|
||||||
|
if start < 0:
|
||||||
|
start = 0
|
||||||
|
if end >= waveform.shape[-1]:
|
||||||
|
end = waveform.shape[-1] - 1
|
||||||
|
|
||||||
|
sliced = waveform[:, start:end]
|
||||||
|
|
||||||
|
error = validate_waveform( sliced, sample_rate, min_only=True )
|
||||||
|
if trim and not error:
|
||||||
|
sliced = torchaudio.functional.vad( sliced, sample_rate )
|
||||||
|
|
||||||
|
return sliced, error
|
||||||
|
|
||||||
|
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
messages = []
|
messages = []
|
||||||
|
@ -1130,6 +1154,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
|
||||||
if not os.path.exists(infile):
|
if not os.path.exists(infile):
|
||||||
raise Exception(f"Missing dataset: {infile}")
|
raise Exception(f"Missing dataset: {infile}")
|
||||||
|
|
||||||
|
if results is None:
|
||||||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
|
|
||||||
files = 0
|
files = 0
|
||||||
|
@ -1140,37 +1165,35 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
|
||||||
path = f'./training/{voice}/{filename}'
|
path = f'./training/{voice}/{filename}'
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
messages.append(f"Missing source audio: {filename}")
|
message = f"Missing source audio: {filename}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
files += 1
|
files += 1
|
||||||
result = results[filename]
|
result = results[filename]
|
||||||
waveform, sample_rate = torchaudio.load(path)
|
waveform, sample_rate = torchaudio.load(path)
|
||||||
|
num_channels, num_frames = waveform.shape
|
||||||
|
duration = num_channels * num_frames / sample_rate
|
||||||
|
|
||||||
for segment in result['segments']:
|
for segment in result['segments']:
|
||||||
start = int((segment['start'] + start_offset) * sample_rate)
|
|
||||||
end = int((segment['end'] + end_offset) * sample_rate)
|
|
||||||
|
|
||||||
if start < 0:
|
|
||||||
start = 0
|
|
||||||
if end >= waveform.shape[-1]:
|
|
||||||
end = waveform.shape[-1] - 1
|
|
||||||
|
|
||||||
sliced = waveform[:, start:end]
|
|
||||||
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
||||||
|
|
||||||
if trim_silence:
|
sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
|
||||||
sliced = torchaudio.functional.vad( sliced, sample_rate )
|
if error:
|
||||||
|
message = f"{error}, skipping... {file}"
|
||||||
sliced, sample_rate = resample( sliced, sample_rate, 22050 )
|
print(message)
|
||||||
torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
|
messages.append(message)
|
||||||
|
continue
|
||||||
|
sliced, _ = resample( sliced, sample_rate, 22050 )
|
||||||
|
torchaudio.save(f"{indir}/audio/{file}", sliced, 22050)
|
||||||
|
|
||||||
segments +=1
|
segments +=1
|
||||||
|
|
||||||
messages.append(f"Sliced segments: {files} => {segments}.")
|
messages.append(f"Sliced segments: {files} => {segments}.")
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
def prepare_dataset( voice, use_segments, text_length, audio_length ):
|
def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
messages = []
|
messages = []
|
||||||
|
@ -1187,31 +1210,67 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ):
|
||||||
|
|
||||||
for filename in results:
|
for filename in results:
|
||||||
result = results[filename]
|
result = results[filename]
|
||||||
segments = result['segments'] if use_segments else [{'text': result['text']}]
|
use_segment = use_segments
|
||||||
for segment in segments:
|
|
||||||
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
|
# check if unsegmented audio exceeds 11.6s
|
||||||
path = f'{indir}/audio/{file}'
|
if not use_segment:
|
||||||
|
path = f'{indir}/audio/{filename}'
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
messages.append(f"Missing source audio: {file}")
|
messages.append(f"Missing source audio: {filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = torchaudio.info(path)
|
||||||
|
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
|
||||||
|
if duration >= MAX_TRAINING_DURATION:
|
||||||
|
message = f"Audio too large, using segments: {filename}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
use_segment = True
|
||||||
|
|
||||||
|
segments = result['segments'] if use_segment else [{'text': result['text']}]
|
||||||
|
|
||||||
|
for segment in segments:
|
||||||
|
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segment else filename
|
||||||
|
path = f'{indir}/audio/{file}'
|
||||||
|
# segment when needed
|
||||||
|
if not os.path.exists(path):
|
||||||
|
tmp_results = {}
|
||||||
|
tmp_results[filename] = result
|
||||||
|
print(f"Audio not segmented, segmenting: {filename}")
|
||||||
|
message = slice_dataset( voice, results=tmp_results )
|
||||||
|
print(message)
|
||||||
|
messages = messages + message.split("\n")
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
message = f"Missing source audio: {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
text = segment['text'].strip()
|
text = segment['text'].strip()
|
||||||
|
normalized_text = text
|
||||||
|
|
||||||
if len(text) > 200:
|
if len(text) > 200:
|
||||||
messages.append(f"[{file}] Text length too long (200 < {len(text)}), skipping...")
|
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
waveform, sample_rate = torchaudio.load(path)
|
waveform, sample_rate = torchaudio.load(path)
|
||||||
num_channels, num_frames = waveform.shape
|
|
||||||
duration = num_channels * num_frames / sample_rate
|
|
||||||
|
|
||||||
error = validate_waveform( waveform, sample_rate )
|
error = validate_waveform( waveform, sample_rate )
|
||||||
if error:
|
if error:
|
||||||
messages.append(f"[{file}]: {error}, skipping...")
|
message = f"{error}, skipping... {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
culled = len(text) < text_length
|
culled = len(text) < text_length
|
||||||
if not culled and audio_length > 0:
|
if not culled and audio_length > 0:
|
||||||
|
num_channels, num_frames = waveform.shape
|
||||||
|
duration = num_channels * num_frames / sample_rate
|
||||||
culled = duration < audio_length
|
culled = duration < audio_length
|
||||||
|
|
||||||
|
# lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}')
|
||||||
lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
|
lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
|
||||||
|
|
||||||
training_joined = "\n".join(lines['training'])
|
training_joined = "\n".join(lines['training'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user