ugh
commit 075ffef68a (parent 953015748f)
@@ -111,9 +111,7 @@ def speaker_similarity_embedding(
 def batch_similar_utterances(
 	speaker_path,
 	yaml,
-	text=False,
 
-	audio_backend="encodec",
 	device="cuda",
 	dtype="float16",
 	amp=False,
@@ -121,13 +119,18 @@ def batch_similar_utterances(
 	verbose=False,
 	metadata_path=None,
 	top_k=8,
+	top_p=0.5,
 	metadata_keys=[],
 
 	trim_duration=0,
 	min_duration=0,
 	max_duration=0,
 
-	storage_backend="slop"
+	audio_backend="encodec",
+	storage_backend="slop",
+	similarity_backend="resp",
+
+	return_features=False,
 ):
 	global tts
 
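Note: the signature rework replaces the old text=False flag with a general similarity_backend switch ("text", "resp", or "wavlm"), keeps audio_backend strictly for the codec, and adds a top_p score floor plus a return_features toggle. A minimal sketch of a call under the new signature; the paths are hypothetical:

from pathlib import Path

similarities, features = batch_similar_utterances(
	speaker_path=Path("./voices/speaker"),  # hypothetical speaker directory
	yaml=Path("./config.yaml"),             # hypothetical config
	top_k=8,
	top_p=0.5,                  # matches scoring at or below this are dropped
	audio_backend="encodec",
	storage_backend="slop",
	similarity_backend="resp",  # "text" | "resp" | "wavlm"
	return_features=True,       # also get the per-file embeddings back
)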
@@ -151,12 +154,18 @@ def batch_similar_utterances(
 	if not speaker_path.exists():
 		return
 
+	# to-do: find decent thresholds
+	"""
+	if similarity_backend != "wavlm":
+		top_p = float("-inf")
+	"""
+
 	# compute features (embeddings if quantized already, MFCC features if raw audio)
 	for filename in tqdm(os.listdir(f'./{speaker_path}/'), desc=f"Encoding '{speaker_path.name}'", disable=not verbose):
 		extension = filename.split(".")[-1]
 		filename = filename.replace(f".{extension}", "")
 
-		if text:
+		if similarity_backend == "text":
 			if extension not in artifact_extension:
 				raise Exception("!")
 
@@ -183,34 +192,37 @@ def batch_similar_utterances(
 				phn = phn.replace(f"({metadata['language']})", "")
 
 			embedding = tts.text_embedding( phn )
-		else:
+		elif similarity_backend == "resp":
 			# treat embeddings as features, if provided quantized audio
-			if extension in artifact_extension:
-				artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
-				duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+			if extension not in artifact_extension:
+				continue
+			artifact = np.load(f'./{speaker_path}/{filename}.{extension}', allow_pickle=True)[()]
+			duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
 
 			"""
 			if 0 < min_duration and duration < min_duration:
 				continue
 
 			if 0 < max_duration and max_duration < duration:
 				continue
 			"""
 
 			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
 
 			if trim_duration > 0:
 				qnt = trim( qnt, int( cfg.dataset.frames_per_second * trim_duration ) )
 
 			embedding = tts.audio_embedding( qnt )
 		# try and extract features from the raw audio itself
+		else:
+			# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
+			if similarity_backend == "wavlm":
+				embedding = speaker_similarity_embedding( f'./{speaker_path}/{filename}.{extension}' )
 			else:
-				# qnt = tts.encode_audio(f'./{speaker_path}/{filename}', trim_length=3.0).to(device)
 				wav, sr = load_audio( f'./{speaker_path}/{filename}.{extension}' )
 
-				duration = wav.shape[-1] / sr
-
 				"""
+				duration = wav.shape[-1] / sr
 				if 0 < min_duration and duration < min_duration:
 					continue
 
@@ -223,6 +235,7 @@ def batch_similar_utterances(
 
 				embedding = mfcc(wav.to(device))[0].t()
 
+		dim_shape = embedding.shape[-1]
 		if slop:
 			embedding = embedding.sum(dim=0)
 
@@ -254,10 +267,10 @@ def batch_similar_utterances(
 	top_k = min( top_k, len(keys) )
 
 	if top_k == 0:
-		return
+		top_k = len(keys)
 
 	# fill any missing keys with a null embedding to keep the order the same
-	null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype )
+	null_embedding = torch.zeros( (dim_shape,), device=tts.device, dtype=tts.dtype )
 	embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )
 	sorted_similarities = {}
 
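Note: the null filler previously hard-coded a width of 1024, which only matched one backend's embeddings; capturing dim_shape from embedding.shape[-1] lets it track whatever width the chosen backend produced. A tiny illustration:

import torch

embedding = torch.randn(80)      # e.g. a summed MFCC feature vector
dim_shape = embedding.shape[-1]  # 80 here; 1024 for the resp backend

# placeholder for files that yielded no features, same width as the rest
null_embedding = torch.zeros( (dim_shape,) )
assert null_embedding.shape == embedding.shape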
@@ -277,10 +290,12 @@ def batch_similar_utterances(
 			similarities[index] = float("-inf")
 
 		topk = torch.topk(similarities, k=top_k, largest=True, sorted=True)
-		similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) ]
+		similarities = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) if score > top_p ]
 
 		sorted_similarities[filename] = similarities
 
+	if return_features:
+		return sorted_similarities, features
+
 	return sorted_similarities
 
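Note: top_p here is just a plain similarity-score floor applied after the top-k selection, not nucleus sampling. A self-contained sketch of the same select-then-filter step, with made-up scores:

import torch

keys = ["a", "b", "c", "d", "e"]
similarities = torch.tensor([0.91, 0.42, 0.77, 0.13, 0.66])
top_k, top_p = 3, 0.5

topk = torch.topk(similarities, k=top_k, largest=True, sorted=True)
# keep only matches strictly above the floor, mirroring `if score > top_p`
matches = [ (index, keys[index], score) for index, score in zip( topk.indices.tolist(), topk.values.tolist() ) if score > top_p ]
print(matches)  # -> [(0, 'a', ~0.91), (2, 'c', ~0.77), (4, 'e', ~0.66)]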
@@ -292,17 +307,20 @@ def main():
 	parser.add_argument("--use-dataset", action="store_true")
 
 	parser.add_argument("--yaml", type=Path)
-	parser.add_argument("--text", action="store_true")
+	parser.add_argument("--out-path", type=Path, default=None)
 	# dropped, because this might mess with the indices to map to
 	"""
 	parser.add_argument("--trim-duration", type=float, default=3.0)
 	parser.add_argument("--min-duration", type=float, default=0)
 	parser.add_argument("--max-duration", type=float, default=0)
 	"""
-	parser.add_argument("--storage-backend", type=str, default="slop")
 	parser.add_argument("--top-k", type=int, default=8)
+	parser.add_argument("--top-p", type=float, default=0.5)
+
+	parser.add_argument("--storage-backend", type=str, default="slop")
+	parser.add_argument("--similarity-backend", type=str, default="resp")
 	parser.add_argument("--audio-backend", type=str, default="encodec")
 
 	parser.add_argument("--dtype", type=str, default="float16")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--device", type=str, default="cuda")
@@ -336,15 +354,17 @@ def main():
 			similarities = batch_similar_utterances(
 				speaker_path=cfg.data_dir / speaker_name,
 				yaml=args.yaml,
-				text=args.text,
 				top_k=args.top_k,
+				top_p=args.top_p,
 				#trim_duration=args.trim_duration,
 				#min_duration=args.min_duration,
 				#max_duration=args.max_duration,
+				audio_backend=args.audio_backend,
 				storage_backend=args.storage_backend,
+				similarity_backend=args.similarity_backend,
+
 				metadata_keys=metadata_keys,
 
-				audio_backend=args.audio_backend,
 				device=args.device,
 				dtype=args.dtype,
 				amp=args.amp,
@@ -381,25 +401,36 @@ def main():
 			add( data_dir, type="noise", texts=False )
 
 	elif args.input_speaker:
-		similarities = batch_similar_utterances(
+		similarities, features = batch_similar_utterances(
 			speaker_path=args.input_speaker,
 			yaml=args.yaml,
-			text=args.text,
 			top_k=args.top_k,
+			top_p=args.top_p,
+
 			#trim_duration=args.trim_duration,
 			#min_duration=args.min_duration,
 			#max_duration=args.max_duration,
 
-			audio_backend=args.audio_backend,
 			device=args.device,
 			dtype=args.dtype,
 			amp=args.amp,
 
+			audio_backend=args.audio_backend,
 			storage_backend=args.storage_backend,
+			similarity_backend=args.similarity_backend,
 
 			verbose=True,
+			return_features=True,
 		)
 
+		features_json = {}
+		for k, v in features.items():
+			features_json[k] = [ x.item() for x in v ]
+
+		if args.out_path is not None:
+			json_write( similarities, args.out_path / "similarities.json" )
+			json_write( features_json, args.out_path / "embeddings.json" )
+
 		# and print
 		for filename, sim in similarities.items():
 			print(f'{filename}: {sim}')
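Note: with the new arguments wired through main(), a single-speaker run that also dumps the JSON artifacts might look like the following; the module entry point and paths are guesses, while the flag names come straight from the argparse definitions above:

python3 -m vall_e.emb.similar \
	--input-speaker ./voices/speaker \
	--yaml ./config.yaml \
	--top-k 8 --top-p 0.5 \
	--similarity-backend resp \
	--out-path ./voices/speaker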
@@ -165,11 +165,17 @@ def transcribe(
 	start = 0
 	end = 0
 	segments = []
-	for segment in result["chunks"]:
+
+	info = torchaudio.info(audio)
+	duration = info.num_frames / info.sample_rate
+
+	for segment in result["chunks"]:
 		text = segment["text"]
 
 		if "timestamp" in segment:
 			s, e = segment["timestamp"]
+			if not e:
+				e = duration
 			start = min( start, s )
 			end = max( end, e )
 		else:
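Note: Whisper-style chunks can carry a None end timestamp on the final segment; the new guard clamps it to the file's duration via torchaudio. A minimal standalone sketch of the same fallback, with a hypothetical file path:

import torchaudio

audio = "./audio/sample.wav"  # hypothetical input file
info = torchaudio.info(audio)
duration = info.num_frames / info.sample_rate

s, e = 12.4, None  # a chunk whose end timestamp is missing
if not e:
	e = duration  # fall back to the full length of the audio
print(s, e)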
@@ -285,12 +291,12 @@ def transcribe_batch(
 		if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
 			continue
 
-		if group_name in ignore_groups:
+		if dataset_name in ignore_groups:
 			continue
-		if only_groups and group_name not in only_groups:
+		if only_groups and dataset_name not in only_groups:
 			continue
 
-		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/'), stride=stride, stride_offset=stride_offset), desc="Processing speaker"):
 			if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
 				continue
 
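Note: process_items isn't shown in this diff; threading stride/stride_offset through it suggests round-robin sharding so parallel workers can split the speaker list without overlap. A plausible sketch, assuming that is what process_items does:

def process_items(items, stride=0, stride_offset=0):
	# assumed behavior: with stride N, the worker with stride_offset=k
	# takes every N-th item, so N processes cover the list disjointly
	items = sorted(items)
	if stride <= 1:
		return items
	return [ item for i, item in enumerate(items) if i % stride == stride_offset ]

# worker 1 of 2 gets the odd-indexed speakers
print(process_items(["s0", "s1", "s2", "s3"], stride=2, stride_offset=1))  # ['s1', 's3']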
@@ -46,15 +46,11 @@ except Exception as e:
 	ERROR_ARCHES["bitnet"] = e
 	pass
 
-from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
-AVAILABLE_ARCHES.append("mixtral")
-"""
 try:
 	from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
 	AVAILABLE_ARCHES.append("mixtral")
 except Exception as e:
 	ERROR_ARCHES["mixtral"] = e
-"""
 
 try:
 	from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config
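Note: this restores mixtral to the same guarded-import pattern used by the other arches, where a failed import is recorded rather than raised so the rest of the package still loads. The pattern in isolation, with a hypothetical arch name:

AVAILABLE_ARCHES = []
ERROR_ARCHES = {}

try:
	from .hypothetical_arch import HypotheticalModel  # hypothetical module
	AVAILABLE_ARCHES.append("hypothetical")
except Exception as e:
	# defer the failure; callers can consult ERROR_ARCHES["hypothetical"]
	ERROR_ARCHES["hypothetical"] = e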