2024-09-10 21:34:23 +00:00
"""
# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py)
# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset
"""
import os
import argparse
import torch
import torchaudio
import numpy as np
import logging
2024-09-17 20:51:45 +00:00
from itertools import combinations
2024-09-10 21:34:23 +00:00
_logger = logging . getLogger ( __name__ )
from tqdm . auto import tqdm
from pathlib import Path
import torchaudio . functional as F
import torchaudio . transforms as T
from . . config import cfg
2024-09-17 20:25:12 +00:00
from . . utils import truncate_json
2024-09-18 03:57:04 +00:00
from . . utils . io import json_read , json_write
2024-09-10 21:34:23 +00:00
from . g2p import encode as phonemize
from . qnt import encode as quantize , trim , convert_audio
from . . webui import init_tts
def load_audio(path):
    """Read an audio file and bring it to the model's sample rate.

    Args:
        path: location of an audio file readable by ``torchaudio.load``.

    Returns:
        ``(waveform, sample_rate)``; ``sample_rate`` is always ``cfg.sample_rate``.
    """
    audio, native_rate = torchaudio.load(path)

    # Average all channels down to one when the file is multi-channel.
    channel_count = audio.shape[0]
    if channel_count > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Resample to the configured rate (the trailing 1 is presumably the
    # target channel count expected by convert_audio).
    audio = convert_audio(audio, native_rate, cfg.sample_rate, 1)
    return audio, cfg.sample_rate
2024-09-17 21:26:30 +00:00
# Shared TTS handle: process() creates it with init_tts() on its first call
# (while it is still None) and reuses it on every later call.
tts = None
2024-09-10 21:34:23 +00:00
def _artifact_phonemes(artifact_metadata):
    """Return the phoneme string for a quantized artifact's metadata, or None.

    Uses the stored ``"phonemes"`` when present, otherwise phonemizes ``"text"``
    in the artifact's ``"language"`` (``"en"`` when absent or empty). Any
    ``(en)`` / ``(<language>)`` markers are stripped from the result. Returns
    None when the metadata carries neither phonemes nor text.
    """
    lang = artifact_metadata.get("language") or "en"
    if "phonemes" in artifact_metadata:
        phn = artifact_metadata["phonemes"]
    elif "text" in artifact_metadata:
        phn = phonemize(artifact_metadata["text"], language=lang)
    else:
        return None

    phn = phn.replace("(en)", "")
    if lang != "en":
        phn = phn.replace(f"({lang})", "")
    return phn


def _stack_features(features):
    """Stack ``features`` into one tensor, zero-filling entries that are None.

    Row ``i`` corresponds to the ``i``-th key of ``features`` so positions stay
    aligned with the metadata key order. The placeholder copies the shape,
    dtype and device of the first real embedding. Returns None when no entry
    has an embedding.
    """
    template = next((feature for feature in features.values() if feature is not None), None)
    if template is None:
        return None
    null_embedding = torch.zeros_like(template)
    return torch.stack([feature if feature is not None else null_embedding for feature in features.values()])


def _top_k_similar(features, embeddings, top_k, desc, verbose=True):
    """Find the ``top_k`` most cosine-similar entries for every embedded entry.

    Args:
        features: ordered mapping of name -> embedding (or None if missing).
        embeddings: ``_stack_features(features)``; row i belongs to key i.
        top_k: maximum neighbours per entry.
        desc: progress-bar label.
        verbose: show the progress bar.

    Returns:
        dict mapping each embedded name to ``(index, name, score)`` tuples,
        most similar first, where ``index`` is the neighbour's row. Missing
        (None) entries and the entry itself are never neighbours. Returns None
        when fewer than two entries have embeddings.
    """
    keys = list(features)
    present = [feature is not None for feature in features.values()]

    # An entry cannot be its own neighbour, so at most (#present - 1) exist.
    top_k = min(top_k, sum(present) - 1)
    if top_k <= 0:
        return None

    missing = torch.tensor([not is_present for is_present in present], device=embeddings.device)
    results = {}
    for row, name in tqdm(enumerate(keys), total=len(keys), desc=desc, disable=not verbose):
        if not present[row]:
            continue
        scores = torch.nn.functional.cosine_similarity(embeddings[row].unsqueeze(0), embeddings, dim=1)
        # Zero placeholders and the query itself must never be reported.
        scores = scores.masked_fill(missing, float("-inf"))
        scores[row] = float("-inf")
        # Sorting everything is slow; topk only orders the winners.
        best = torch.topk(scores, k=top_k, largest=True, sorted=True)
        results[name] = [(i, keys[i], score) for i, score in zip(best.indices.tolist(), best.values.tolist())]
    return results


def process(
    speaker_path,
    yaml,
    text=False,
    audio_backend="encodec",
    device="cuda",
    dtype="float16",
    amp=False,
    verbose=False,
    metadata_path=None,
    top_k=8,
    metadata_keys=None,
    trim_duration=0,
    min_duration=0,
    max_duration=0,
    storage_backend="slop",
):
    """Embed every file in ``speaker_path`` and find each file's most similar peers.

    Each file becomes a single embedding (per-frame embeddings are summed):

    * ``text=True``: quantized artifacts are embedded from their phonemes with
      ``tts.text_embedding``;
    * otherwise, quantized artifacts (extension found in
      ``cfg.audio_backend_extension``) go through ``tts.audio_embedding`` and
      any other file is loaded as audio and turned into MFCC features.

    Args:
        speaker_path: directory holding the files to compare.
        yaml: model config handed to ``init_tts`` on first use.
        text: embed transcriptions instead of audio.
        audio_backend, device, dtype, amp: model/backend settings.
        verbose: show progress bars.
        metadata_path: unused.
        top_k: number of neighbours kept per file.
        metadata_keys: existing metadata entry names; their order fixes the
            result indices. Files not listed are appended after them, sorted.
        trim_duration: if > 0, quantized codes are passed to ``trim`` with
            ``cfg.dataset.frames_per_second * trim_duration`` frames.
        min_duration, max_duration: currently ignored (duration filtering was
            disabled because it could shift indices).
        storage_backend: ``"faiss"`` returns a ``faiss.IndexFlatL2`` over the
            embeddings; anything else computes cosine-similarity neighbours.

    Returns:
        None when ``speaker_path`` does not exist or there is nothing to
        compare; a FAISS index for the ``"faiss"`` backend; otherwise a dict
        mapping each embedded filename (without extension) to a list of
        ``(index, filename, score)`` tuples, most similar first.

    Raises:
        ValueError: ``text=True`` and a file is not a quantized artifact.
    """
    global tts

    cfg.set_audio_backend(audio_backend)
    artifact_extension = cfg.audio_backend_extension

    cfg.inference.weight_dtype = dtype  # "bfloat16"
    cfg.inference.amp = amp  # False

    # easy way to load the model and handle encoding audio
    if tts is None:
        tts = init_tts(yaml=yaml, restart=False, device=device, dtype=dtype)

    speaker_path = Path(speaker_path)
    if not speaker_path.exists():
        return None

    # Seed with the known metadata keys, in order, so row i of the stacked
    # embeddings matches entry i of the metadata JSON; keys whose file is never
    # found stay None and are masked out of the similarity search.
    features = dict.fromkeys(metadata_keys or [])

    mfcc = None
    # governs whether to sum the entire sequence of embeddings into one embedding
    slop = True

    # compute features (embeddings if quantized already, MFCC features if raw audio)
    entries = sorted(entry for entry in speaker_path.iterdir() if entry.is_file())
    for entry in tqdm(entries, desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
        extension = entry.suffix[1:]
        filename = entry.stem
        if not extension:
            continue  # neither an artifact nor an audio file (e.g. dotfiles)

        if text:
            if extension not in artifact_extension:
                raise ValueError(f"--text needs quantized artifacts ({artifact_extension}), got: {entry}")
            artifact = np.load(entry, allow_pickle=True)[()]
            phn = _artifact_phonemes(artifact["metadata"])
            if phn is None:
                _logger.warning("Skipping %s: artifact has neither phonemes nor text", entry)
                continue
            embedding = tts.text_embedding(phn)
        elif extension in artifact_extension:
            # treat embeddings as features, if provided quantized audio
            artifact = np.load(entry, allow_pickle=True)[()]
            qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
            if trim_duration > 0:
                qnt = trim(qnt, int(cfg.dataset.frames_per_second * trim_duration))
            embedding = tts.audio_embedding(qnt)
        else:
            # try and extract features from the raw audio itself
            wav, sr = load_audio(str(entry))
            if mfcc is None:
                # the transform holds buffers, so it must live on the input's device
                mfcc = T.MFCC(sample_rate=cfg.sample_rate).to(device)
            embedding = mfcc(wav.to(device))[0].t()

        if slop:
            embedding = embedding.sum(dim=0)

        features[filename] = embedding

    embeddings = _stack_features(features)
    if embeddings is None:
        return None

    # rely on FAISS to handle storing embeddings and handling queries
    # will probably explode in size fast...........
    if storage_backend == "faiss":
        import faiss

        # FAISS consumes contiguous float32 numpy arrays, not torch tensors.
        vectors = np.ascontiguousarray(embeddings.detach().float().cpu().numpy())
        index = faiss.IndexFlatL2(vectors.shape[-1])
        index.add(vectors)
        # TODO: support just querying for a list of similar entries to cram into the JSON metadata
        return index

    # do batch cosine similarity processing
    return _top_k_similar(
        features,
        embeddings,
        top_k,
        desc=f"Computing similarities: {speaker_path.name}",
        verbose=verbose,
    )
def main():
    """Command-line entry point: compute per-utterance similarity lists.

    * ``--use-dataset``: for each speaker directory in ``cfg.dataset.training``,
      ``validation`` and ``noise``, store the indices of each utterance's
      ``--top-k`` most similar utterances under ``"similar"`` in
      ``cfg.metadata_dir / f"{speaker}.json"``. With ``--storage-backend faiss``,
      write a ``.faiss`` index next to that JSON instead.
    * ``--input-speaker DIR``: compute similarities for one directory and print them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-speaker", type=Path, default=None)
    parser.add_argument("--use-dataset", action="store_true")
    parser.add_argument("--yaml", type=Path)
    parser.add_argument("--text", action="store_true")
    # --trim-duration / --min-duration / --max-duration were dropped, because
    # trimming or filtering files might mess with the indices to map to.
    parser.add_argument("--storage-backend", type=str, default="slop")
    parser.add_argument("--top-k", type=int, default=8)
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="skip speakers whose metadata already contains similarity lists",
    )

    parser.add_argument("--audio-backend", type=str, default="encodec")
    parser.add_argument("--dtype", type=str, default="float16")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()

    # Arguments shared by every process() call.
    common = dict(
        yaml=args.yaml,
        text=args.text,
        top_k=args.top_k,
        storage_backend=args.storage_backend,
        audio_backend=args.audio_backend,
        device=args.device,
        dtype=args.dtype,
        amp=args.amp,
        verbose=True,
    )

    if args.use_dataset:
        cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

        def add(data_dir):
            """Compute and persist similarity lists for one speaker directory."""
            speaker_name = str(data_dir).replace(str(cfg.data_dir), "")
            metadata_path = cfg.metadata_dir / f"{speaker_name}.json"

            metadata = json_read(metadata_path, default={}) or {}
            metadata_keys = list(metadata)
            if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]:
                return

            similarities = process(
                speaker_path=cfg.data_dir / speaker_name,
                metadata_keys=metadata_keys,
                **common,
            )
            if similarities is None:
                return

            if args.storage_backend == "faiss":
                import faiss  # optional dependency, only needed for this backend

                faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
                return

            if not similarities:
                return

            # Give every embedded file an entry; new files are appended after the
            # existing ones, in the same order process() indexed them.
            for filename in similarities:
                metadata.setdefault(filename, {})
            # Map each name to its position in the JSON as it will be written
            # (O(1) lookups, and valid for newly added files too).
            position = {key: i for i, key in enumerate(metadata)}
            for filename, similar in similarities.items():
                metadata[filename]["similar"] = [position[name] for _, name, _ in similar]

            json_write(metadata, metadata_path)

        splits = (
            ("Processing Training", cfg.dataset.training),
            ("Processing Validation", cfg.dataset.validation),
            ("Processing Noise", cfg.dataset.noise),
        )
        for desc, data_dirs in splits:
            for data_dir in tqdm(sorted(data_dirs), desc=desc):
                add(data_dir)

    elif args.input_speaker:
        similarities = process(speaker_path=args.input_speaker, **common)
        # process() returns None when there is nothing to compare and a FAISS
        # index (not a dict) for the faiss backend; only dicts can be printed.
        if isinstance(similarities, dict):
            for filename, sim in similarities.items():
                print(f"{filename}: {sim}")
    else:
        parser.error("one of --use-dataset or --input-speaker is required")
2024-09-10 21:34:23 +00:00
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()