2024-08-06 00:40:50 +00:00
"""
# Processes the audio found under --input-audio, pairing it with the annotated transcriptions found under --input-metadata (as produced by transcribe.py)
# Outputs NumPy objects containing the quantized audio and its accompanying metadata under --output-dataset, ready to be loaded by the trainer
"""
2024-04-18 18:32:41 +00:00
import os
import json
2024-08-05 20:59:25 +00:00
import argparse
2024-04-18 18:32:41 +00:00
import torch
2024-04-19 02:34:28 +00:00
import torchaudio
2024-05-18 15:13:58 +00:00
import numpy as np
2024-08-05 20:59:25 +00:00
2024-04-18 18:32:41 +00:00
from tqdm . auto import tqdm
from pathlib import Path
2024-08-06 00:40:50 +00:00
from . . config import cfg
2024-04-18 18:32:41 +00:00
2024-08-06 13:17:25 +00:00
# need to validate if this is safe to import before modifying the config
from . g2p import encode as phonemize
2024-08-07 01:23:33 +00:00
from . qnt import encode as quantize
2024-08-06 13:17:25 +00:00
2024-04-19 02:34:28 +00:00
def pad(num, zeroes):
    """Return ``num`` as a string, zero-filled on the left to ``zeroes + 1`` characters.

    Values whose string form is already at least that wide are returned unchanged.
    """
    width = zeroes + 1
    return f"{num}".zfill(width)
2024-04-18 18:32:41 +00:00
2024-08-07 01:35:15 +00:00
def load_audio(path):
    """Load an audio file and collapse it to a single channel.

    Returns a ``(waveform, sample_rate)`` tuple as produced by
    ``torchaudio.load``. When the loaded waveform has more than one entry
    along its first dimension, those entries are averaged together while the
    dimension itself is kept.
    """
    audio, rate = torchaudio.load(path)

    channel_count = audio.shape[0]
    if channel_count > 1:
        # mix channels down by averaging them
        audio = audio.mean(dim=0, keepdim=True)

    return audio, rate
2024-08-06 19:24:40 +00:00
2024-08-06 01:12:13 +00:00
def process_items(items, stride=0, stride_offset=0):
    """Sort ``items`` and optionally keep only every ``stride``-th entry.

    With ``stride == 0`` the full sorted list is returned. Otherwise an entry
    at sorted position ``i`` is kept when ``(i + stride_offset) % stride == 0``,
    which lets several workers each take a disjoint share of the same listing.
    """
    ordered = sorted(items)
    if stride == 0:
        return ordered

    kept = []
    for position, entry in enumerate(ordered):
        if (position + stride_offset) % stride == 0:
            kept.append(entry)
    return kept
2024-08-05 20:59:25 +00:00
2024-08-07 01:23:33 +00:00
def process_job(outpath, waveform, sample_rate, text=None, language="en"):
    """Quantize one waveform and save it, with its metadata, to ``outpath``.

    Args:
        outpath: destination file for the saved object.
        waveform: audio tensor; it is moved to ``cfg.device`` before quantizing.
        sample_rate: sample rate of ``waveform``, forwarded to the quantizer.
        text: optional transcription. When truthy, the stripped text, its
            phonemes and ``language`` are added to the stored metadata.
        language: language passed to the phonemizer and stored alongside the text.

    The saved object is a dict with ``"codes"`` (uint16) and ``"metadata"``.
    Because it is a dict rather than an array, NumPy pickles it on save, so it
    has to be read back with ``allow_pickle=True``.
    """
    qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device)

    if cfg.audio_backend == "dac":
        # the DAC path returns an object carrying its own bookkeeping fields,
        # which are copied into the metadata as-is
        state_dict = {
            "codes": qnt.codes.cpu().numpy().astype(np.uint16),
            "metadata": {
                "original_length": qnt.original_length,
                "sample_rate": qnt.sample_rate,
                "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
                "chunk_length": qnt.chunk_length,
                "channels": qnt.channels,
                "padding": qnt.padding,
                "dac_version": "1.0.0",
            },
        }
    else:
        # other backends return the codes directly as a tensor
        state_dict = {
            "codes": qnt.cpu().numpy().astype(np.uint16),
            "metadata": {
                "original_length": waveform.shape[-1],
                "sample_rate": sample_rate,
            },
        }

    if text:
        text = text.strip()
        state_dict['metadata'] |= {
            "text": text,
            "phonemes": phonemize(text, language=language),
            "language": language,
        }

    # use a context manager so the handle is flushed and closed even if saving fails
    with open(outpath, "wb") as f:
        np.save(f, state_dict)
2024-08-06 19:24:40 +00:00
def process_jobs(jobs, speaker_id="", raise_exceptions=True):
    """Run ``process_job`` over a batch of queued jobs behind a progress bar.

    Each job is an ``(outpath, waveform, sample_rate, text, language)`` tuple.
    A failing job is reported on stdout; it is re-raised when
    ``raise_exceptions`` is set, otherwise the batch moves on to the next job.
    An empty (or ``None``) batch is a no-op.
    """
    if jobs:
        progress = tqdm(jobs, desc=f"Quantizing: {speaker_id}")
        for outpath, waveform, sample_rate, text, language in progress:
            try:
                process_job(outpath, waveform, sample_rate, text, language)
            except Exception as e:
                print(f"Failed to quantize: {outpath}:", e)
                if raise_exceptions:
                    raise e
2024-08-06 00:40:50 +00:00
def process(
    audio_backend="encodec",
    input_audio="voices",
    input_metadata="metadata",
    output_dataset="training",
    raise_exceptions=False,
    stride=0,
    stride_offset=0,
    slice="auto",
    low_memory=False,
    device="cuda",
    dtype="float16",
    amp=False,
):
    """Quantize every transcribed utterance under ``input_audio`` into ``output_dataset``.

    The input is walked as ``<input_audio>/<group>/<speaker>/<file>``, and each
    speaker's transcription is read from
    ``<input_metadata>/<group>/<speaker>/whisper.json``. Outputs that already
    exist are skipped, so an interrupted run can be resumed.

    Args:
        audio_backend: backend name handed to ``cfg.set_audio_backend``.
        input_audio: root directory of the source audio.
        input_metadata: root directory of the transcription metadata.
        output_dataset: root directory for the output; a
            ``<rate>KHz-<backend>`` subdirectory is appended to it.
        raise_exceptions: re-raise quantization failures instead of only printing them.
        stride: when non-zero, only every ``stride``-th speaker of a group is processed.
        stride_offset: offset applied to the speaker index before the stride test.
        slice: ``True`` to always cut files into their transcribed segments,
            ``"auto"`` to do so only when a speaker has a single transcribed
            file; any other value disables slicing (groups listed in
            ``always_slice_groups`` are always sliced).
        low_memory: quantize after each source file instead of once per speaker.
        device: accepted for interface compatibility; see the NOTE(review) below.
        dtype: stored in ``cfg.inference.weight_dtype``.
        amp: stored in ``cfg.inference.amp``.

    Side effects: writes one quantized file per utterance plus
    ``missing.json`` (missing transcriptions / audio files) and
    ``dataset.json`` (list of ``group/speaker`` entries) into the output directory.
    """
    # prepare from args
    cfg.set_audio_backend(audio_backend)
    audio_extension = cfg.audio_backend_extension

    cfg.inference.weight_dtype = dtype # "bfloat16"
    cfg.inference.amp = amp # False

    # NOTE(review): `device` is not applied anywhere here; process_job() quantizes on cfg.device.
    # Confirm whether cfg.device should be set from this argument.

    # resolves to a "24KHz-", "48KHz-" or "44KHz-" prefixed directory depending on cfg.sample_rate
    output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"

    # to-do: make this also prepared from args
    language_map = {} # k = group, v = language

    ignore_groups = [] # skip these groups
    ignore_speakers = [] # skip these speakers

    only_groups = [] # only process these groups
    only_speakers = [] # only process these speakers

    always_slice_groups = ["Audiobooks", "LibriVox"] # always slice from this group
    audio_only = ["Noise"] # special pathway for processing audio only (without a transcription)

    missing = {
        "transcription": [],
        "audio": []
    }
    dataset = []

    for group_name in sorted(os.listdir(f'./{input_audio}/')):
        group_dir = f'./{input_audio}/{group_name}/'
        if not os.path.isdir(group_dir):
            print("Is not dir:", group_dir)
            continue

        if group_name in ignore_groups:
            continue
        if only_groups and group_name not in only_groups:
            continue

        speakers = process_items(os.listdir(group_dir), stride=stride, stride_offset=stride_offset)
        for speaker_id in tqdm(speakers, desc=f"Processing speaker in {group_name}"):
            speaker_dir = f'./{input_audio}/{group_name}/{speaker_id}'
            if not os.path.isdir(speaker_dir):
                print("Is not dir:", speaker_dir)
                continue

            if speaker_id in ignore_speakers:
                continue
            if only_speakers and speaker_id not in only_speakers:
                continue

            output_dir = f'./{output_dataset}/{group_name}/{speaker_id}'
            os.makedirs(output_dir, exist_ok=True)

            if speaker_id in audio_only:
                for filename in sorted(os.listdir(speaker_dir)):
                    inpath = Path(f'{speaker_dir}/{filename}')
                    outpath = Path(f'{output_dir}/{filename}').with_suffix(audio_extension)

                    if outpath.exists():
                        continue

                    waveform, sample_rate = load_audio(inpath)
                    # process_job() does the quantization itself; quantizing here as well
                    # would only repeat the work and throw the result away
                    process_job(outpath, waveform, sample_rate)
                continue

            metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
            if not metadata_path.exists():
                missing["transcription"].append(str(metadata_path))
                continue

            try:
                with open(metadata_path, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
            except Exception:
                # an unreadable or malformed transcription is reported the same way as a missing one
                missing["transcription"].append(str(metadata_path))
                continue

            speaker_key = f'{group_name}/{speaker_id}'
            if speaker_key not in dataset:
                dataset.append(speaker_key)

            jobs = []

            use_slices = slice == True or (slice == "auto" and len(metadata) == 1) or group_name in always_slice_groups

            for filename in sorted(metadata.keys()):
                inpath = Path(f'{speaker_dir}/{filename}')
                if not inpath.exists():
                    missing["audio"].append(str(inpath))
                    continue

                # split off only the trailing extension; a plain str.replace() would also
                # remove the same text from the middle of the name
                fname, extension = os.path.splitext(filename)
                entry = metadata[filename]

                waveform, sample_rate = None, None
                if group_name in language_map:
                    language = language_map[group_name]
                else:
                    language = entry.get("language", "en")

                segments = entry["segments"]
                if not segments or not use_slices:
                    # whole-file utterance
                    outpath = Path(f'{output_dir}/{fname}{extension}').with_suffix(audio_extension)
                    text = entry["text"]

                    if not text or outpath.exists():
                        continue

                    waveform, sample_rate = load_audio(inpath)
                    jobs.append((outpath, waveform, sample_rate, text, language))
                else:
                    # one utterance per transcribed segment; the index advances even for
                    # skipped segments so output names stay stable across runs
                    for i, segment in enumerate(segments):
                        segment_id = pad(i, 4)

                        outpath = Path(f'{output_dir}/{fname}_{segment_id}{extension}').with_suffix(audio_extension)
                        text = segment["text"]

                        if not text or outpath.exists():
                            continue

                        # audio not already loaded, load it
                        if waveform is None:
                            waveform, sample_rate = load_audio(inpath)

                        # clamp to the waveform; the slice end is exclusive, so the
                        # upper bound is the full length rather than length - 1
                        start = max(int(segment['start'] * sample_rate), 0)
                        end = min(int(segment['end'] * sample_rate), waveform.shape[-1])

                        # nothing to quantize in an empty or inverted range
                        if end <= start:
                            continue

                        jobs.append((outpath, waveform[:, start:end], sample_rate, text, language))

                # processes audio files one at a time
                if low_memory:
                    process_jobs(jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions)
                    jobs = []

            # processes all audio files for a given speaker
            if not low_memory:
                process_jobs(jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions)
                jobs = []

    # the output directory only exists so far if at least one speaker was visited
    os.makedirs(f"./{output_dataset}", exist_ok=True)
    with open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8') as f:
        f.write(json.dumps(missing))
    with open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8') as f:
        f.write(json.dumps(dataset))
2024-08-05 20:59:25 +00:00
def main():
    """Parse the command-line arguments and run ``process`` with them.

    Passing a bare number as ``--device`` (e.g. ``--device 1``) is shorthand
    for multi-GPU sharding: the stride becomes the number of visible CUDA
    devices, the stride offset becomes that number, and the device becomes
    ``cuda:<number>``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--audio-backend", type=str, default="encodec")
    parser.add_argument("--input-audio", type=str, default="voices")
    parser.add_argument("--input-metadata", type=str, default="training/metadata")
    parser.add_argument("--output-dataset", type=str, default="training/dataset")
    parser.add_argument("--raise-exceptions", action="store_true")
    parser.add_argument("--low-memory", action="store_true")
    parser.add_argument("--stride", type=int, default=0)
    parser.add_argument("--stride-offset", type=int, default=0)
    parser.add_argument("--slice", type=str, default="auto")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--amp", action="store_true")

    args = parser.parse_args()

    # do some assumption magic
    # to-do: find a nice way to spawn multiple processes where tqdm plays nicely
    if args.device.isnumeric():
        args.stride = torch.cuda.device_count()
        args.stride_offset = int(args.device)
        args.device = f'cuda:{args.device}'

    # --slice is parsed as a string, but process() only forces slicing for a real
    # boolean True; convert the true-like spellings so the option can be enabled
    slice_mode = args.slice
    if slice_mode.strip().lower() in ("true", "1", "yes"):
        slice_mode = True

    process(
        audio_backend=args.audio_backend,
        input_audio=args.input_audio,
        input_metadata=args.input_metadata,
        output_dataset=args.output_dataset,
        raise_exceptions=args.raise_exceptions,
        stride=args.stride,
        stride_offset=args.stride_offset,
        slice=slice_mode,

        low_memory=args.low_memory,

        device=args.device,
        dtype=args.dtype,
        amp=args.amp,
    )


if __name__ == "__main__":
    main()