2024-08-06 00:40:50 +00:00
"""
# Processes the audio provided through --input-audio, using the annotated transcriptions provided through --input-metadata (as produced by transcribe.py)
# Outputs NumPy objects containing the quantized audio and its metadata to --output-dataset, ready to be loaded by the trainer
"""
2024-04-18 18:32:41 +00:00
import os
import json
2024-08-05 20:59:25 +00:00
import argparse
2024-04-18 18:32:41 +00:00
import torch
2024-04-19 02:34:28 +00:00
import torchaudio
2024-05-18 15:13:58 +00:00
import numpy as np
2024-08-05 20:59:25 +00:00
2024-04-18 18:32:41 +00:00
from tqdm . auto import tqdm
from pathlib import Path
2024-08-06 00:40:50 +00:00
from . . config import cfg
2024-04-18 18:32:41 +00:00
2024-04-19 02:34:28 +00:00
def pad(num, zeroes):
    """Return `num` as a string, left-padded with zeroes to `zeroes + 1` characters."""
    width = zeroes + 1
    text = str(num)
    return text.zfill(width)
2024-04-18 18:32:41 +00:00
2024-08-06 01:12:13 +00:00
def process_items(items, stride=0, stride_offset=0):
    """Sort `items` and, when `stride` is non-zero, keep only every `stride`-th entry.

    An entry at sorted position `i` is kept when `(i + stride_offset) % stride == 0`.
    A `stride` of 0 disables the filtering and returns the whole sorted list.
    """
    ordered = list(items)
    ordered.sort()

    # no striding requested: hand back everything
    if stride == 0:
        return ordered

    selected = []
    for position, entry in enumerate(ordered):
        if (position + stride_offset) % stride == 0:
            selected.append(entry)
    return selected
2024-08-05 20:59:25 +00:00
2024-08-06 00:40:50 +00:00
def _normalize_slice_mode(value):
    """Map string spellings of the `slice` option onto the values process() compares against.

    The CLI delivers `--slice` as a string, so "true"/"false" style values are
    converted to booleans here. Anything unrecognized is returned unchanged.
    """
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("true", "yes", "1"):
            return True
        if lowered in ("false", "no", "0"):
            return False
        if lowered == "auto":
            return "auto"
    return value


def _load_mono(path):
    """Load an audio file with torchaudio, averaging the channels when there is more than one."""
    waveform, sample_rate = torchaudio.load(path)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    return waveform, sample_rate


def _build_payload(qnt, waveform, sample_rate, is_dac, extra_metadata=None):
    """Assemble the {"codes", "metadata"} dict that gets saved for one utterance.

    For the DAC backend the metadata is read from attributes of `qnt`; otherwise
    `qnt` itself holds the codes and the metadata comes from the waveform.
    `extra_metadata` (text / phonemes / language) is appended after the base keys.
    """
    if is_dac:
        codes = qnt.codes
        metadata = {
            "original_length": qnt.original_length,
            "sample_rate": qnt.sample_rate,
            "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
            "chunk_length": qnt.chunk_length,
            "channels": qnt.channels,
            "padding": qnt.padding,
            "dac_version": "1.0.0",
        }
    else:
        codes = qnt
        metadata = {
            "original_length": waveform.shape[-1],
            "sample_rate": sample_rate,
        }

    if extra_metadata:
        metadata.update(extra_metadata)

    return {
        "codes": codes.cpu().numpy().astype(np.uint16),
        "metadata": metadata,
    }


def _save_payload(path, payload):
    """Write `payload` to `path` with np.save, closing the file afterwards."""
    # pass a file handle rather than the path: np.save appends ".npy" to path arguments
    with open(path, "wb") as f:
        np.save(f, payload)


def process(
    audio_backend="encodec",
    input_audio="voices",
    input_metadata="metadata",
    output_dataset="training",
    raise_exceptions=False,
    stride=0,
    stride_offset=0,
    slice="auto",
    device="cuda",
    dtype="float16",
    amp=False,
):
    """Quantize every transcribed utterance under `input_audio` into `output_dataset`.

    Walks `./{input_audio}/{group}/{speaker}/`, reads the matching
    `./{input_metadata}/{group}/{speaker}/whisper.json`, phonemizes and quantizes
    each utterance (or each segment, when slicing), and saves one NumPy object per
    utterance. Speakers named "Noise" are quantized without any transcription.
    Already existing outputs are skipped. Finally `missing.json` and `dataset.json`
    are written to the output directory.

    Args:
        audio_backend: "encodec", "vocos", "dac" or "audiodec"; selects the output
            file extension and overrides `cfg.sample_rate` / `cfg.model.resp_levels`.
        input_audio: root directory of the source audio.
        input_metadata: root directory of the transcription metadata.
        output_dataset: root output directory; a "<rate>KHz-<backend>" subdirectory is used.
        raise_exceptions: re-raise quantization failures instead of only printing them.
        stride, stride_offset: forwarded to process_items() to subsample the speakers.
        slice: True to always cut utterances into their segments, "auto" to do so only
            when a speaker has a single transcribed file; string spellings such as
            "true" / "false" (as received from the CLI) are accepted as well.
        device: device passed to the quantizer.
        dtype: stored in `cfg.inference.weight_dtype`.
        amp: stored in `cfg.inference.amp`.

    Raises:
        Exception: when `audio_backend` is not one of the known backends.
    """
    # encodec / vocos
    if audio_backend in ["encodec", "vocos"]:
        audio_extension = ".enc"
        cfg.sample_rate = 24_000
        cfg.model.resp_levels = 8
    elif audio_backend == "dac":
        audio_extension = ".dac"
        cfg.sample_rate = 44_100
        cfg.model.resp_levels = 9
    # fixed: this used to test cfg.audio_backend and assign a dead local `sample_rate`
    elif audio_backend == "audiodec":
        audio_extension = ".dec"
        cfg.sample_rate = 48_000
        cfg.model.resp_levels = 8  # ?
    else:
        raise Exception(f"Unknown audio backend: {audio_backend}")

    # prepare from args
    cfg.audio_backend = audio_backend  # "encodec"
    cfg.inference.weight_dtype = dtype  # "bfloat16"
    cfg.inference.amp = amp  # False

    # import after because we've overridden the config above
    # need to validate if this is even necessary anymore
    from .g2p import encode as phonemize
    from .qnt import encode as quantize, _replace_file_extension

    # yields 24KHz / 44KHz / 48KHz for the sample rates set above
    output_dataset = f"{output_dataset}/{cfg.sample_rate // 1000}KHz-{cfg.audio_backend}"  # "training"
    is_dac = cfg.audio_backend == "dac"
    slice_mode = _normalize_slice_mode(slice)

    language_map = {}  # k = group, v = language

    ignore_groups = []  # skip these groups
    ignore_speakers = []  # skip these speakers

    only_groups = []  # only process these groups
    only_speakers = []  # only process these speakers

    always_slice_groups = []  # always slice from this group

    missing = {
        "transcription": [],
        "audio": [],
    }
    dataset = []

    for group_name in sorted(os.listdir(f'./{input_audio}/')):
        group_dir = f'./{input_audio}/{group_name}/'
        if not os.path.isdir(group_dir):
            print("Is not dir:", group_dir)
            continue

        if group_name in ignore_groups:
            continue
        if only_groups and group_name not in only_groups:
            continue

        speakers = process_items(os.listdir(group_dir), stride=stride, stride_offset=stride_offset)
        for speaker_id in tqdm(speakers, desc=f"Processing speaker in {group_name}"):
            speaker_dir = f'./{input_audio}/{group_name}/{speaker_id}'
            if not os.path.isdir(speaker_dir):
                print("Is not dir:", speaker_dir)
                continue

            if speaker_id in ignore_speakers:
                continue
            if only_speakers and speaker_id not in only_speakers:
                continue

            out_dir = f'./{output_dataset}/{group_name}/{speaker_id}'
            os.makedirs(f'{out_dir}/', exist_ok=True)

            # noise clips carry no transcription: quantize them as-is
            if speaker_id == "Noise":
                for filename in sorted(os.listdir(f'{speaker_dir}/')):
                    inpath = Path(f'{speaker_dir}/{filename}')
                    target = _replace_file_extension(Path(f'{out_dir}/{filename}'), audio_extension)

                    if target.exists():
                        continue

                    waveform, sample_rate = torchaudio.load(inpath)
                    qnt = quantize(waveform, sr=sample_rate, device=device)

                    # build the payload before opening the file so a failure leaves no empty output behind
                    _save_payload(target, _build_payload(qnt, waveform, sample_rate, is_dac))

                continue

            metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
            if not metadata_path.exists():
                missing["transcription"].append(str(metadata_path))
                continue

            try:
                with open(metadata_path, "r", encoding="utf-8") as f:
                    metadata = json.loads(f.read())
            except (OSError, ValueError):
                # unreadable or malformed JSON counts as a missing transcription
                missing["transcription"].append(str(metadata_path))
                continue

            if f'{group_name}/{speaker_id}' not in dataset:
                dataset.append(f'{group_name}/{speaker_id}')

            wavs = []

            use_slices = (
                slice_mode == True  # noqa: E712 - keeps accepting 1 like the original comparison
                or (slice_mode == "auto" and len(metadata) == 1)
                or group_name in always_slice_groups
            )

            for filename in sorted(metadata):
                inpath = Path(f'{speaker_dir}/{filename}')
                if not inpath.exists():
                    missing["audio"].append(str(inpath))
                    continue

                entry = metadata[filename]

                # split off only the trailing extension; str.replace() would strip every occurrence
                fname, dotted_extension = os.path.splitext(filename)
                extension = dotted_extension[1:]

                if group_name in language_map:
                    language = language_map[group_name]
                else:
                    language = entry["language"] if "language" in entry else "en"

                if len(entry["segments"]) == 0 or not use_slices:
                    outpath = Path(f'{out_dir}/{fname}.{extension}')
                    text = entry["text"]

                    if len(text) == 0:
                        continue

                    if _replace_file_extension(outpath, audio_extension).exists():
                        continue

                    waveform, sample_rate = _load_mono(inpath)
                    wavs.append((outpath, text, language, waveform, sample_rate))
                else:
                    # the audio is loaded lazily, only once a segment actually needs it
                    waveform, sample_rate = None, None

                    for i, segment in enumerate(entry["segments"]):
                        segment_id = pad(i, 4)
                        outpath = Path(f'{out_dir}/{fname}_{segment_id}.{extension}')
                        text = segment["text"]

                        if len(text) == 0:
                            continue

                        if _replace_file_extension(outpath, audio_extension).exists():
                            continue

                        if waveform is None:
                            waveform, sample_rate = _load_mono(inpath)

                        start = int(segment['start'] * sample_rate)
                        end = int(segment['end'] * sample_rate)

                        if start < 0:
                            start = 0
                        # NOTE(review): clamping to shape[-1] - 1 drops the final sample of a
                        # tail segment; kept as-is to not alter existing outputs
                        if end >= waveform.shape[-1]:
                            end = waveform.shape[-1] - 1

                        # nothing to quantize in an empty or inverted segment
                        if end <= start:
                            continue

                        wavs.append((outpath, text, language, waveform[:, start:end], sample_rate))

            if wavs:
                for outpath, text, language, waveform, sample_rate in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
                    try:
                        phones = phonemize(text, language=language)
                        qnt = quantize(waveform, sr=sample_rate, device=device)

                        payload = _build_payload(qnt, waveform, sample_rate, is_dac, {
                            "text": text.strip(),
                            "phonemes": "".join(phones),
                            "language": language,
                        })
                        _save_payload(_replace_file_extension(outpath, audio_extension), payload)
                    except Exception as e:
                        # best-effort: report the failure and move on unless told otherwise
                        print(f"Failed to quantize: {outpath}:", e)
                        if raise_exceptions:
                            raise
                        continue

    # the output root does not exist yet when no speaker was processed
    os.makedirs(f"./{output_dataset}/", exist_ok=True)
    with open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8') as f:
        f.write(json.dumps(missing))
    with open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8') as f:
        f.write(json.dumps(dataset))
2024-08-05 20:59:25 +00:00
def main():
    """Parse the command line options and run process() with them."""
    # (flag, add_argument keyword arguments), in the order they show up in --help
    options = (
        ("--audio-backend", {"type": str, "default": "encodec"}),
        ("--dtype", {"type": str, "default": "bfloat16"}),
        ("--amp", {"action": "store_true"}),
        ("--input-audio", {"type": str, "default": "voices"}),
        ("--input-metadata", {"type": str, "default": "training/metadata"}),
        ("--output-dataset", {"type": str, "default": "training/dataset"}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--raise-exceptions", {"action": "store_true"}),
        ("--stride", {"type": int, "default": 0}),
        ("--stride-offset", {"type": int, "default": 0}),
        ("--slice", {"type": str, "default": "auto"}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    # every option's destination name matches a keyword parameter of process()
    process(**vars(parser.parse_args()))
2024-08-05 20:59:25 +00:00
# only run the CLI when executed as a script, not when imported as a module
if __name__ == "__main__":
    main()