2024-08-06 01:34:58 +00:00
"""
# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py)
# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset
"""
2023-08-26 15:21:12 +00:00
import os
import json
2024-08-06 01:34:58 +00:00
import argparse
2023-08-26 15:21:12 +00:00
import torch
2024-08-06 01:34:58 +00:00
import torchaudio
2024-05-25 16:07:52 +00:00
import numpy as np
2023-08-26 15:21:12 +00:00
from tqdm . auto import tqdm
from pathlib import Path
2024-05-25 16:07:52 +00:00
from vall_e . config import cfg
2024-08-06 01:34:58 +00:00
def pad ( num , zeroes ) :
return str ( num ) . zfill ( zeroes + 1 )
def process_items ( items , stride = 0 , stride_offset = 0 ) :
items = sorted ( items )
return items if stride == 0 else [ item for i , item in enumerate ( items ) if ( i + stride_offset ) % stride == 0 ]
def process (
audio_backend = " encodec " ,
input_audio = " LibriTTS_R " ,
output_dataset = " training " ,
raise_exceptions = False ,
stride = 0 ,
stride_offset = 0 ,
slice = " auto " ,
device = " cuda " ,
dtype = " float16 " ,
amp = False ,
) :
# encodec / vocos
if audio_backend in [ " encodec " , " vocos " ] :
audio_extension = " .enc "
cfg . sample_rate = 24_000
cfg . model . resp_levels = 8
elif audio_backend == " dac " :
audio_extension = " .dac "
cfg . sample_rate = 44_100
cfg . model . resp_levels = 9
elif cfg . audio_backend == " audiodec " :
sample_rate = 48_000
audio_extension = " .dec "
cfg . model . resp_levels = 8 # ?
else :
raise Exception ( f " Unknown audio backend: { audio_backend } " )
# prepare from args
cfg . audio_backend = audio_backend # "encodec"
cfg . inference . weight_dtype = dtype # "bfloat16"
cfg . inference . amp = amp # False
# import after because we've overriden the config above
# need to validate if this is even necessary anymore
from vall_e . emb . g2p import encode as phonemize
from vall_e . emb . qnt import encode as quantize , _replace_file_extension
output_dataset = f " { output_dataset } / { ' 2 ' if cfg . sample_rate == 24_000 else ' 4 ' } { ' 8 ' if cfg . sample_rate == 48_000 else ' 4 ' } KHz- { cfg . audio_backend } " # "training"
language_map = { } # k = group, v = language
ignore_groups = [ ] # skip these groups
ignore_speakers = [ ] # skip these speakers
only_groups = [ ] # only process these groups
only_speakers = [ ] # only process these speakers
always_slice_groups = [ ] # always slice from this group
missing = {
" transcription " : [ ] ,
" audio " : [ ]
}
dataset = [ ]
# Layout: ./LibriTTS_R/train-clean-100/103/1241
for group_name in sorted ( os . listdir ( f ' ./ { input_audio } / ' ) ) :
if not os . path . isdir ( f ' ./ { input_audio } / { group_name } / ' ) :
print ( " Is not dir: " , f ' ./ { input_audio } / { group_name } / ' )
continue
2024-05-25 16:07:52 +00:00
2024-08-06 01:34:58 +00:00
if group_name in ignore_groups :
continue
if only_groups and group_name not in only_groups :
continue
2023-08-26 15:21:12 +00:00
2024-08-06 01:34:58 +00:00
for speaker_id in tqdm ( process_items ( os . listdir ( f ' ./ { input_audio } / { group_name } / ' ) , stride = stride , stride_offset = stride_offset ) , desc = f " Processing speaker in { group_name } " ) :
if not os . path . isdir ( f ' ./ { input_audio } / { group_name } / { speaker_id } ' ) :
print ( " Is not dir: " , f ' ./ { input_audio } / { group_name } / { speaker_id } ' )
continue
if speaker_id in ignore_speakers :
continue
if only_speakers and speaker_id not in only_speakers :
continue
2024-05-25 16:07:52 +00:00
2024-08-06 01:34:58 +00:00
os . makedirs ( f ' ./ { output_dataset } / { group_name } / { speaker_id } / ' , exist_ok = True )
2023-08-26 15:21:12 +00:00
2024-08-06 01:34:58 +00:00
if f ' { group_name } / { speaker_id } ' not in dataset :
dataset . append ( f ' { group_name } / { speaker_id } ' )
2023-08-26 15:21:12 +00:00
2024-08-06 01:34:58 +00:00
txts = [ ]
wavs = [ ]
2023-08-26 15:21:12 +00:00
2024-08-06 01:38:29 +00:00
for book_id in os . listdir ( f ' ./ { input_audio } / { group_name } / { speaker_id } ' ) :
2024-08-06 01:34:58 +00:00
if not os . path . isdir ( f ' ./ { input_audio } / { group_name } / { speaker_id } / { book_id } ' ) :
print ( " Is not dir: " , f ' ./ { input_audio } / { group_name } / { speaker_id } / { book_id } ' )
continue
2024-04-18 18:32:41 +00:00
2024-08-06 01:38:29 +00:00
for filename in os . listdir ( f ' ./ { input_audio } / { group_name } / { speaker_id } / { book_id } ' ) :
if " .wav " not in filename :
continue
2024-08-06 01:34:58 +00:00
inpath = Path ( f ' ./ { input_audio } / { group_name } / { speaker_id } / { book_id } / { filename } ' )
2024-08-06 13:17:25 +00:00
textpath = _replace_file_extension ( inpath , " .original.txt " )
if not inpath . exists ( ) or not textpath . exists ( ) :
2024-08-06 01:34:58 +00:00
missing [ " audio " ] . append ( str ( inpath ) )
2024-08-06 13:17:25 +00:00
continue
2024-04-18 18:32:41 +00:00
2024-08-06 01:34:58 +00:00
extension = os . path . splitext ( filename ) [ - 1 ] [ 1 : ]
fname = filename . replace ( f ' . { extension } ' , " " )
2024-05-25 16:07:52 +00:00
2024-08-06 01:34:58 +00:00
waveform , sample_rate = None , None
language = " en "
2024-05-25 16:07:52 +00:00
2024-08-06 01:34:58 +00:00
outpath = Path ( f ' ./ { output_dataset } / { group_name } / { speaker_id } / { fname } . { extension } ' )
2024-08-06 13:17:25 +00:00
text = open ( textpath , " r " , encoding = " utf-8 " ) . read ( )
2024-05-25 16:07:52 +00:00
2024-08-06 01:34:58 +00:00
if len ( text ) == 0 :
continue
if _replace_file_extension ( outpath , audio_extension ) . exists ( ) :
continue
if waveform is None :
waveform , sample_rate = torchaudio . load ( inpath )
if waveform . shape [ 0 ] > 1 :
waveform = torch . mean ( waveform , dim = 0 , keepdim = True )
wavs . append ( (
outpath ,
text ,
language ,
waveform ,
sample_rate
) )
if len ( wavs ) > 0 :
for job in tqdm ( wavs , desc = f " Quantizing: { speaker_id } " ) :
try :
outpath , text , language , waveform , sample_rate = job
phones = phonemize ( text , language = language )
qnt = quantize ( waveform , sr = sample_rate , device = device )
if cfg . audio_backend == " dac " :
np . save ( open ( _replace_file_extension ( outpath , audio_extension ) , " wb " ) , {
" codes " : qnt . codes . cpu ( ) . numpy ( ) . astype ( np . uint16 ) ,
" metadata " : {
" original_length " : qnt . original_length ,
" sample_rate " : qnt . sample_rate ,
" input_db " : qnt . input_db . cpu ( ) . numpy ( ) . astype ( np . float32 ) ,
" chunk_length " : qnt . chunk_length ,
" channels " : qnt . channels ,
" padding " : qnt . padding ,
" dac_version " : " 1.0.0 " ,
" text " : text . strip ( ) ,
" phonemes " : " " . join ( phones ) ,
" language " : language ,
} ,
} )
else :
np . save ( open ( _replace_file_extension ( outpath , audio_extension ) , " wb " ) , {
" codes " : qnt . cpu ( ) . numpy ( ) . astype ( np . uint16 ) ,
" metadata " : {
" original_length " : waveform . shape [ - 1 ] ,
" sample_rate " : sample_rate ,
" text " : text . strip ( ) ,
" phonemes " : " " . join ( phones ) ,
" language " : language ,
} ,
} )
except Exception as e :
print ( f " Failed to quantize: { outpath } : " , e )
if raise_exceptions :
raise e
continue
open ( f " ./ { output_dataset } /missing.json " , ' w ' , encoding = ' utf-8 ' ) . write ( json . dumps ( missing ) )
open ( f " ./ { output_dataset } /dataset.json " , ' w ' , encoding = ' utf-8 ' ) . write ( json . dumps ( dataset ) )
def main ( ) :
parser = argparse . ArgumentParser ( )
parser . add_argument ( " --audio-backend " , type = str , default = " encodec " )
parser . add_argument ( " --dtype " , type = str , default = " bfloat16 " )
parser . add_argument ( " --amp " , action = " store_true " )
parser . add_argument ( " --input-audio " , type = str , default = " LibriTTS_R " )
parser . add_argument ( " --output-dataset " , type = str , default = " training/dataset " )
parser . add_argument ( " --device " , type = str , default = " cuda " )
parser . add_argument ( " --raise-exceptions " , action = " store_true " )
parser . add_argument ( " --stride " , type = int , default = 0 )
parser . add_argument ( " --stride-offset " , type = int , default = 0 )
parser . add_argument ( " --slice " , type = str , default = " auto " )
args = parser . parse_args ( )
2024-08-06 13:17:25 +00:00
# do some assumption magic
# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
if args . device . isnumeric ( ) :
args . stride = torch . cuda . device_count ( )
args . stride_offset = int ( args . device )
args . device = f ' cuda: { args . device } '
2024-08-06 01:34:58 +00:00
process (
audio_backend = args . audio_backend ,
input_audio = args . input_audio ,
output_dataset = args . output_dataset ,
raise_exceptions = args . raise_exceptions ,
stride = args . stride ,
stride_offset = args . stride_offset ,
slice = args . slice ,
device = args . device ,
dtype = args . dtype ,
amp = args . amp ,
)
if __name__ == " __main__ " :
main ( )