more tweaks
This commit is contained in:
parent
804ddb5182
commit
8f41d1b324
|
@ -40,6 +40,8 @@ def load_audio( path ):
|
||||||
|
|
||||||
return waveform, sr
|
return waveform, sr
|
||||||
|
|
||||||
|
tts = None
|
||||||
|
|
||||||
def process(
|
def process(
|
||||||
speaker_path,
|
speaker_path,
|
||||||
yaml,
|
yaml,
|
||||||
|
@ -52,7 +54,12 @@ def process(
|
||||||
|
|
||||||
verbose=False,
|
verbose=False,
|
||||||
metadata_path=None,
|
metadata_path=None,
|
||||||
|
|
||||||
|
maximum_duration=0,
|
||||||
|
#use_faiss=True,
|
||||||
):
|
):
|
||||||
|
global tts
|
||||||
|
|
||||||
cfg.set_audio_backend(audio_backend)
|
cfg.set_audio_backend(audio_backend)
|
||||||
artifact_extension = cfg.audio_backend_extension
|
artifact_extension = cfg.audio_backend_extension
|
||||||
|
|
||||||
|
@ -60,7 +67,8 @@ def process(
|
||||||
cfg.inference.amp = amp # False
|
cfg.inference.amp = amp # False
|
||||||
|
|
||||||
# easy way to load the model and handle encoding audio
|
# easy way to load the model and handle encoding audio
|
||||||
tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
|
if tts is None:
|
||||||
|
tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
|
||||||
|
|
||||||
queue = []
|
queue = []
|
||||||
features = {}
|
features = {}
|
||||||
|
@ -69,6 +77,13 @@ def process(
|
||||||
|
|
||||||
mfcc = T.MFCC(sample_rate=cfg.sample_rate)
|
mfcc = T.MFCC(sample_rate=cfg.sample_rate)
|
||||||
|
|
||||||
|
"""
|
||||||
|
# too slow
|
||||||
|
if use_faiss:
|
||||||
|
import faiss
|
||||||
|
index = None
|
||||||
|
"""
|
||||||
|
|
||||||
# compute features (embeddings if quantized already, MFCC features if raw audio)
|
# compute features (embeddings if quantized already, MFCC features if raw audio)
|
||||||
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
|
for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path}'", disable=not verbose):
|
||||||
extension = filename.split(".")[-1]
|
extension = filename.split(".")[-1]
|
||||||
|
@ -91,20 +106,40 @@ def process(
|
||||||
if lang != "en":
|
if lang != "en":
|
||||||
phn = phn.replace(f"({metadata['language']})", "")
|
phn = phn.replace(f"({metadata['language']})", "")
|
||||||
|
|
||||||
features[filename] = tts.text_embedding( phn )
|
embedding = tts.text_embedding( phn )
|
||||||
else:
|
else:
|
||||||
# treat embeddings as features, if provided quantized audio
|
# treat embeddings as features, if provided quantized audio
|
||||||
if extension in artifact_extension:
|
if extension in artifact_extension:
|
||||||
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
|
artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
|
||||||
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
|
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
|
||||||
qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
|
if maximum_duration > 0:
|
||||||
|
qnt = trim( qnt, int( cfg.dataset.frames_per_second * maximum_duration ) )
|
||||||
|
|
||||||
features[filename] = tts.audio_embedding( qnt )
|
embedding = tts.audio_embedding( qnt )
|
||||||
# try and extract features from the raw audio itself
|
# try and extract features from the raw audio itself
|
||||||
else:
|
else:
|
||||||
# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
|
# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
|
||||||
wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
|
wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
|
||||||
features[filename] = mfcc(wav.to(device))[0].t()
|
embedding = mfcc(wav.to(device))[0].t()
|
||||||
|
|
||||||
|
features[filename] = embedding
|
||||||
|
|
||||||
|
"""
|
||||||
|
if use_faiss:
|
||||||
|
if index is None:
|
||||||
|
shape = embedding.shape
|
||||||
|
index = faiss.IndexFlatL2(shape[1])
|
||||||
|
|
||||||
|
index.add(embedding.cpu())
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
for filename, embedding in features.items():
|
||||||
|
D, I = index.search(embedding.cpu(), k=3)
|
||||||
|
# print(f'{filename}: {I[1:]}')
|
||||||
|
|
||||||
|
if metadata_path is not None:
|
||||||
|
index.save(metadata_path)
|
||||||
|
"""
|
||||||
|
|
||||||
keys = list(features.keys())
|
keys = list(features.keys())
|
||||||
key_range = range(len(keys))
|
key_range = range(len(keys))
|
||||||
|
@ -126,6 +161,10 @@ def process(
|
||||||
|
|
||||||
similarities[key] = similarity
|
similarities[key] = similarity
|
||||||
|
|
||||||
|
# combinations() doesn't have swapped keys
|
||||||
|
if swapped_key not in similarities:
|
||||||
|
similarities[swapped_key] = similarity
|
||||||
|
|
||||||
if index_a not in sorted_similarities:
|
if index_a not in sorted_similarities:
|
||||||
sorted_similarities[index_a] = {}
|
sorted_similarities[index_a] = {}
|
||||||
if index_b not in sorted_similarities[index_a]:
|
if index_b not in sorted_similarities[index_a]:
|
||||||
|
@ -176,11 +215,12 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--yaml", type=Path)
|
parser.add_argument("--yaml", type=Path)
|
||||||
parser.add_argument("--text", action="store_true")
|
parser.add_argument("--text", action="store_true")
|
||||||
|
parser.add_argument("--maximum-duration", type=float, default=3.0)
|
||||||
|
|
||||||
parser.add_argument("--audio-backend", type=str, default="encodec")
|
parser.add_argument("--audio-backend", type=str, default="encodec")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
parser.add_argument("--dtype", type=str, default="float16")
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
parser.add_argument("--device", type=str, default="cpu") # unironically faster
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -199,6 +239,7 @@ def main():
|
||||||
metadata_path=cfg.metadata_dir / f'{speaker_name}.json',
|
metadata_path=cfg.metadata_dir / f'{speaker_name}.json',
|
||||||
yaml=args.yaml,
|
yaml=args.yaml,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
|
maximum_duration=args.maximum_duration,
|
||||||
|
|
||||||
audio_backend=args.audio_backend,
|
audio_backend=args.audio_backend,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
|
@ -224,6 +265,7 @@ def main():
|
||||||
speaker_path=args.input_speaker,
|
speaker_path=args.input_speaker,
|
||||||
yaml=args.yaml,
|
yaml=args.yaml,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
|
maximum_duration=args.maximum_duration,
|
||||||
|
|
||||||
audio_backend=args.audio_backend,
|
audio_backend=args.audio_backend,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user