2024-09-10 21:34:23 +00:00
"""
# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py)
# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset
"""
import os
import json
import argparse
import torch
import torchaudio
import numpy as np
import logging
2024-09-17 20:51:45 +00:00
from itertools import combinations
2024-09-10 21:34:23 +00:00
# module-level logger for this script
_logger = logging.getLogger(__name__)
from tqdm . auto import tqdm
from pathlib import Path
import torchaudio . functional as F
import torchaudio . transforms as T
from . . config import cfg
2024-09-17 20:25:12 +00:00
from . . utils import truncate_json
2024-09-10 21:34:23 +00:00
from . g2p import encode as phonemize
from . qnt import encode as quantize , trim , convert_audio
from . . webui import init_tts
def load_audio(path):
    """Load an audio file, downmix it and resample it to ``cfg.sample_rate``.

    Returns a ``(waveform, sample_rate)`` pair where ``sample_rate`` is always
    ``cfg.sample_rate``.
    """
    signal, source_rate = torchaudio.load(path)

    # average multi-channel audio down to a single channel
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    target_rate = cfg.sample_rate
    # resample to the configured rate (the trailing 1 is presumably the target channel count)
    signal = convert_audio(signal, source_rate, target_rate, 1)
    return signal, target_rate
2024-09-17 21:26:30 +00:00
# TTS engine shared across process() calls; process() creates it lazily via init_tts() on first use
tts = None
2024-09-10 21:34:23 +00:00
def _outside_duration(duration, min_duration, max_duration):
    """Return True when `duration` (seconds) violates an enabled bound; a bound <= 0 is disabled."""
    if 0 < min_duration and duration < min_duration:
        return True
    return 0 < max_duration and max_duration < duration


def _load_artifact(path):
    """Load a quantized-audio artifact dict saved with ``np.save`` as a pickled 0-d object array."""
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only point this at trusted data.
    return np.load(path, allow_pickle=True)[()]


def _artifact_duration(artifact):
    """Return the source audio duration in seconds recorded in an artifact's metadata."""
    metadata = artifact["metadata"]
    return metadata["original_length"] / metadata["sample_rate"]


def _artifact_phonemes(artifact, path):
    """Return an artifact's phoneme string, phonemizing its transcription when needed.

    Language tags such as "(en)" are stripped from the result.
    Raises ValueError when the metadata has neither "phonemes" nor "text".
    """
    metadata = artifact["metadata"]
    lang = metadata["language"] if "language" in metadata else "en"

    if "phonemes" in metadata:
        phn = metadata["phonemes"]
    elif "text" in metadata:
        phn = phonemize(metadata["text"], language=lang)
    else:
        raise ValueError(f"{path}: artifact metadata has neither 'phonemes' nor 'text'")

    phn = phn.replace("(en)", "")
    if lang != "en":
        phn = phn.replace(f"({lang})", "")
    return phn


def _write_metadata(metadata_path, metadata):
    """Serialize `metadata` to JSON (passed through truncate_json) and write it to `metadata_path`."""
    with open(metadata_path, "w", encoding="utf-8") as f:
        f.write(truncate_json(json.dumps(metadata)))


def process(
    speaker_path,
    yaml,
    text=False,
    audio_backend="encodec",
    device="cuda",
    dtype="float16",
    amp=False,
    verbose=False,
    metadata_path=None,
    trim_duration=0,
    min_duration=0,
    max_duration=0,
    storage_backend="local",
):
    """Score how similar each utterance in ``speaker_path`` is to every other utterance.

    Every file directly inside ``speaker_path`` is turned into a feature tensor:

    * ``text=True``: the artifact's phonemes (or its phonemized ``text``) are embedded
      with ``tts.text_embedding``. Every file must carry a quantized-artifact extension,
      otherwise an Exception is raised.
    * Quantized artifacts (extension in ``cfg.audio_backend_extension``) have their codes
      optionally trimmed and embedded with ``tts.audio_embedding``.
    * Anything else is loaded as raw audio and converted to MFCC features.

    Args:
        speaker_path: directory holding one speaker's files (not searched recursively).
        yaml: config passed to ``init_tts`` the first time this function runs.
        text: embed phonemized transcriptions instead of audio.
        audio_backend: passed to ``cfg.set_audio_backend``; selects the artifact extension.
        device, dtype, amp: inference settings for the TTS model and feature tensors.
        verbose: show progress bars and, for the summed-vector backends, print each
            file's most and least similar match.
        metadata_path: JSON file that receives a ``"similar"`` entry per file (only
            updated if it already exists). For the ``"faiss"`` backend, the index is
            written next to it with a ``.faiss`` suffix instead. Without a metadata_path,
            the faiss backend falls back to the in-memory comparison.
        trim_duration: seconds (converted to frames) passed to ``trim`` for quantized
            audio; 0 disables trimming.
        min_duration, max_duration: skip files shorter or longer than this many seconds;
            0 disables the bound.
        storage_backend: ``"faiss"``, ``"chunkdot"`` and ``"slop"`` sum each feature
            sequence over its first dimension into one vector before comparing. Any other
            value (e.g. the default ``"local"``) compares every pair of files
            frame-by-frame.

    Returns:
        For the summed-vector backends: ``{filename: [(other_filename, score), ...]}``.
        For the pairwise path: ``{index into the file list: [(other_index, score), ...]}``.
        Both are sorted by descending cosine similarity. Returns ``None`` once the FAISS
        index has been written, and ``{}`` when no file survives the filters.
    """
    global tts

    cfg.set_audio_backend(audio_backend)
    artifact_extension = cfg.audio_backend_extension

    cfg.inference.weight_dtype = dtype  # "bfloat16"
    cfg.inference.amp = amp  # False

    # easy way to load the model and handle encoding audio
    if tts is None:
        tts = init_tts(yaml=yaml, restart=False, device=device, dtype=dtype)

    if metadata_path is not None:
        metadata_path = Path(metadata_path)

    # governs whether each feature sequence is summed into a single embedding before comparison
    slop = storage_backend in ("faiss", "chunkdot", "slop")

    speaker_dir = Path(speaker_path)
    mfcc = None
    features = {}

    # compute features (embeddings if quantized already, MFCC features if raw audio)
    for entry in tqdm(os.listdir(speaker_dir), desc=f"Encoding '{speaker_path}'", disable=not verbose):
        path = speaker_dir / entry
        # subdirectories and extension-less files cannot be loaded as artifacts or audio
        if "." not in entry or not path.is_file():
            continue
        filename, extension = entry.rsplit(".", 1)

        if text:
            if extension not in artifact_extension:
                raise Exception(f"{path}: --text requires quantized artifacts ({artifact_extension})")
            artifact = _load_artifact(path)
            if _outside_duration(_artifact_duration(artifact), min_duration, max_duration):
                continue
            embedding = tts.text_embedding(_artifact_phonemes(artifact, path))
        elif extension in artifact_extension:
            # treat embeddings as features, if provided quantized audio
            artifact = _load_artifact(path)
            if _outside_duration(_artifact_duration(artifact), min_duration, max_duration):
                continue
            qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
            if trim_duration > 0:
                qnt = trim(qnt, int(cfg.dataset.frames_per_second * trim_duration))
            embedding = tts.audio_embedding(qnt)
        else:
            # try and extract features from the raw audio itself
            wav, sr = load_audio(path)
            if _outside_duration(wav.shape[-1] / sr, min_duration, max_duration):
                continue
            if mfcc is None:
                # move the transform's buffers to the same device as the waveform it receives
                mfcc = T.MFCC(sample_rate=cfg.sample_rate).to(device)
            embedding = mfcc(wav.to(device))[0].t()

        if slop:
            embedding = embedding.sum(dim=0)

        # NOTE(review): files sharing a stem (e.g. "a.wav" and "a.enc") overwrite each other here
        features[filename] = embedding

    if not features:
        return {}

    # rely on FAISS to handle storing embeddings and handling queries
    # will probably explode in size fast...........
    if storage_backend == "faiss":
        import faiss

        # faiss's Python API takes float32 numpy matrices of shape (n, dim)
        embeddings = torch.stack(list(features.values())).cpu().float().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[-1])
        index.add(embeddings)
        # TODO: query the index (index.search) to cram per-file neighbours into the JSON metadata
        if metadata_path is not None:
            faiss.write_index(index, str(metadata_path.with_suffix(".faiss")))
            return

    # TODO: a chunkdot (cosine_similarity_top_k) backend was prototyped but is disabled;
    # "chunkdot" currently uses the summed-embedding comparison below.

    metadata = None
    if metadata_path is not None and metadata_path.exists():
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

    keys = list(features.keys())

    # do batch cosine similarity processing: one summed vector per file against all the others
    if slop:
        embeddings = torch.stack(list(features.values()))
        sorted_similarities = {}
        for position, filename in enumerate(keys):
            embedding = features[filename].unsqueeze(0)
            scores = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
            ranked = sorted(
                [(keys[other], score) for other, score in enumerate(scores) if other != position],
                key=lambda x: x[1],
                reverse=True,
            )
            sorted_similarities[filename] = ranked

            if metadata is not None:
                metadata.setdefault(filename, {})["similar"] = ranked

            # a lone file has no neighbours to report
            if verbose and ranked:
                most_filename, most_score = ranked[0]
                least_filename, least_score = ranked[-1]
                print(f'{filename}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})')

        if metadata is not None:
            _write_metadata(metadata_path, metadata)
        return sorted_similarities

    # an EXTREMELY naive implementation: compare every pair of feature sequences frame-by-frame,
    # truncated to the shorter of the two, and average the per-frame cosine similarity
    pair_count = len(keys) * (len(keys) - 1) // 2
    scores_by_index = {}
    for index_a, index_b in tqdm(combinations(range(len(keys)), 2), total=pair_count, desc="Computing similarities", disable=not verbose):
        feature_a = features[keys[index_a]]
        feature_b = features[keys[index_b]]
        shortest = min(feature_a.shape[0], feature_b.shape[0])
        similarity = torch.nn.functional.cosine_similarity(feature_a[:shortest, :], feature_b[:shortest, :], dim=1).mean().item()

        # combinations() yields each unordered pair once; the score is symmetric, so store both directions
        scores_by_index.setdefault(index_a, {})[index_b] = similarity
        scores_by_index.setdefault(index_b, {})[index_a] = similarity

    # sort similarities scores
    sorted_similarities = {}
    for index, scores in scores_by_index.items():
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        sorted_similarities[index] = ranked
        if metadata is not None:
            # NOTE(review): unlike the summed-embedding path, entries here are (index into `keys`, score)
            metadata.setdefault(keys[index], {})["similar"] = ranked

    if metadata is not None:
        _write_metadata(metadata_path, metadata)

    return sorted_similarities
def main():
    """Parse CLI arguments and run process() on one speaker or on every dataset speaker.

    ``--input-speaker DIR`` processes a single directory without writing metadata.
    ``--use-dataset`` walks ``cfg.dataset.training``, ``validation`` and ``noise`` and
    updates each speaker's metadata under ``cfg.metadata_dir``. If neither is given,
    the parser exits with a usage error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-speaker", type=Path, default=None)
    parser.add_argument("--use-dataset", action="store_true")
    parser.add_argument("--yaml", type=Path)
    parser.add_argument("--text", action="store_true")
    parser.add_argument("--trim-duration", type=float, default=3.0)
    parser.add_argument("--min-duration", type=float, default=0)
    parser.add_argument("--max-duration", type=float, default=0)
    parser.add_argument("--storage-backend", type=str, default="slop")

    parser.add_argument("--audio-backend", type=str, default="encodec")
    parser.add_argument("--dtype", type=str, default="float32")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--device", type=str, default="cpu")  # unironically faster

    args = parser.parse_args()

    # settings shared by every process() call
    options = dict(
        yaml=args.yaml,
        trim_duration=args.trim_duration,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        storage_backend=args.storage_backend,
        audio_backend=args.audio_backend,
        device=args.device,
        dtype=args.dtype,
        amp=args.amp,
        verbose=True,
    )

    if args.use_dataset:
        cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

        def add(data_dir, texts=True):
            """Process one dataset speaker directory; ``texts=False`` forces audio features."""
            # NOTE(review): assumes data_dir is relative to, or prefixed by, cfg.data_dir — confirm
            speaker_name = str(data_dir).replace(str(cfg.data_dir), "")
            if "LibriTTS-R" in speaker_name:
                speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

            process(
                speaker_path=cfg.data_dir / speaker_name,
                # the JSON metadata file; process() derives the .faiss index path from it via with_suffix
                metadata_path=cfg.metadata_dir / f'{speaker_name}.json',
                # directories flagged texts=False (noise) have no transcriptions to embed
                text=args.text and texts,
                **options,
            )

        # training
        for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
            add(data_dir)
        # validation
        for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
            add(data_dir)
        # noise
        for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
            add(data_dir, texts=False)

    elif args.input_speaker:
        process(
            speaker_path=args.input_speaker,
            text=args.text,
            **options,
        )
    else:
        parser.error("either --use-dataset or --input-speaker is required")
2024-09-10 21:34:23 +00:00
# run the command-line entry point when this module is executed directly
if __name__ == "__main__":
    main()