more tweaks

mrq 2024-09-17 16:26:30 -05:00
parent 804ddb5182
commit 8f41d1b324


@@ -40,6 +40,8 @@ def load_audio( path ):
 	return waveform, sr

+tts = None
+
 def process(
 	speaker_path,
 	yaml,
@ -52,7 +54,12 @@ def process(
verbose=False, verbose=False,
metadata_path=None, metadata_path=None,
maximum_duration=0,
#use_faiss=True,
): ):
global tts
cfg.set_audio_backend(audio_backend) cfg.set_audio_backend(audio_backend)
artifact_extension = cfg.audio_backend_extension artifact_extension = cfg.audio_backend_extension
@@ -60,7 +67,8 @@ def process(
 	cfg.inference.amp = amp # False

 	# easy way to load the model and handle encoding audio
-	tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
+	if tts is None:
+		tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )

 	queue = []
 	features = {}
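The hunk above makes the TTS handle a lazily-initialized module global, so repeated process() calls (one per speaker in the batch path of main()) reuse the loaded model instead of re-initializing it every time. A minimal sketch of the pattern, with a placeholder standing in for the real init_tts():

_model = None  # module-level cache, mirrors the `tts = None` global above

def _expensive_init(device: str):
	# placeholder for init_tts(); assume this loads weights and takes seconds
	return {"device": device}

def get_model(device: str = "cpu"):
	global _model
	if _model is None:                 # only the first call pays the load cost
		_model = _expensive_init(device)
	return _model                      # later calls return the cached handle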
@ -69,6 +77,13 @@ def process(
mfcc = T.MFCC(sample_rate=cfg.sample_rate) mfcc = T.MFCC(sample_rate=cfg.sample_rate)
"""
# too slow
if use_faiss:
import faiss
index = None
"""
# compute features (embeddings if quantized already, MFCC features if raw audio) # compute features (embeddings if quantized already, MFCC features if raw audio)
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose): for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
extension = filename.split(".")[-1] extension = filename.split(".")[-1]
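For files that are still raw audio, the script falls back to MFCC features from torchaudio rather than model embeddings. A self-contained sketch of that path; the sample rate and shapes here are illustrative, not the repo's cfg values:

import torch
import torchaudio.transforms as T

sample_rate = 24_000
mfcc = T.MFCC(sample_rate=sample_rate)    # default n_mfcc=40

wav = torch.randn(1, sample_rate * 2)     # 2 s of fake mono audio
feats = mfcc(wav)                         # (channels, n_mfcc, time)
embedding = feats[0].t()                  # (time, n_mfcc), as in the diff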
@@ -91,20 +106,40 @@ def process(
 			if lang != "en":
 				phn = phn.replace(f"({metadata['language']})", "")

-			features[filename] = tts.text_embedding( phn )
+			embedding = tts.text_embedding( phn )
 		else:
 			# treat embeddings as features, if provided quantized audio
 			if extension in artifact_extension:
 				artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
 				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
-				qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
+				if maximum_duration > 0:
+					qnt = trim( qnt, int( cfg.dataset.frames_per_second * maximum_duration ) )

-				features[filename] = tts.audio_embedding( qnt )
+				embedding = tts.audio_embedding( qnt )
 			# try and extract features from the raw audio itself
 			else:
 				# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
 				wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
-				features[filename] = mfcc(wav.to(device))[0].t()
+				embedding = mfcc(wav.to(device))[0].t()
+
+		features[filename] = embedding
+
+		"""
+		if use_faiss:
+			if index is None:
+				shape = embedding.shape
+				index = faiss.IndexFlatL2(shape[1])

+			index.add(embedding.cpu())
+
+		if verbose:
+			for filename, embedding in features.items():
+				D, I = index.search(embedding.cpu(), k=3)
+				# print(f'{filename}: {I[1:]}')
+
+		if metadata_path is not None:
+			index.save(metadata_path)
+		"""
+
 	keys = list(features.keys())
 	key_range = range(len(keys))
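The commented-out FAISS path is kept for reference but would not run as written even if re-enabled: FAISS indexes have no .save() method (persistence goes through faiss.write_index), and search expects float32 arrays. IndexFlatL2 is also an exact brute-force index, which may be why it's flagged "# too slow". A hedged sketch of the intended flow, with made-up dimensions:

import faiss
import numpy as np

dim = 40                                   # embedding width, illustrative
index = faiss.IndexFlatL2(dim)             # exact (brute-force) L2 search

embeddings = np.random.rand(100, dim).astype("float32")
index.add(embeddings)                      # one row per utterance

D, I = index.search(embeddings[:5], k=3)   # distances and neighbor row ids

faiss.write_index(index, "speaker.faiss")  # working analogue of index.save()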
@@ -126,6 +161,10 @@ def process(
 			similarities[key] = similarity

+			# combinations() doesn't have swapped keys
+			if swapped_key not in similarities:
+				similarities[swapped_key] = similarity
+
 			if index_a not in sorted_similarities:
 				sorted_similarities[index_a] = {}
 			if index_b not in sorted_similarities[index_a]:
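The new swapped-key lines exist because itertools.combinations yields each unordered pair exactly once, so a dict keyed on (a, b) would miss lookups for (b, a). A small demonstration:

from itertools import combinations

keys = ["a.wav", "b.wav", "c.wav"]
similarities = {}
for a, b in combinations(keys, 2):         # yields (a, b) but never (b, a)
	score = 0.5                            # placeholder similarity value
	similarities[(a, b)] = score
	if (b, a) not in similarities:         # mirror the entry, as the diff does
		similarities[(b, a)] = score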
@@ -176,11 +215,12 @@ def main():
 	parser.add_argument("--yaml", type=Path)
 	parser.add_argument("--text", action="store_true")
+	parser.add_argument("--maximum-duration", type=float, default=3.0)
 	parser.add_argument("--audio-backend", type=str, default="encodec")
-	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--dtype", type=str, default="float16")
 	parser.add_argument("--amp", action="store_true")
-	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--device", type=str, default="cpu") # unironically faster
 	args = parser.parse_args()
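The new --maximum-duration flag (default 3.0 s, matching the previously hard-coded factor of 3) feeds the trim in process(), where seconds are converted to codec frames; passing 0 disables trimming via the `maximum_duration > 0` guard. A worked example of the conversion, assuming EnCodec's 75 frames per second at 24 kHz:

frames_per_second = 75                     # EnCodec at 24 kHz; illustrative
maximum_duration = 3.0                     # the new CLI default, in seconds
trim_frames = int(frames_per_second * maximum_duration)
print(trim_frames)                         # 225 frames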
@@ -199,6 +239,7 @@ def main():
 			metadata_path=cfg.metadata_dir / f'{speaker_name}.json',
 			yaml=args.yaml,
 			text=args.text,
+			maximum_duration=args.maximum_duration,
 			audio_backend=args.audio_backend,
 			device=args.device,
@@ -224,6 +265,7 @@ def main():
 		speaker_path=args.input_speaker,
 		yaml=args.yaml,
 		text=args.text,
+		maximum_duration=args.maximum_duration,
 		audio_backend=args.audio_backend,
 		device=args.device,