also do text similarities (dont know what use I'll have for this)

This commit is contained in:
mrq 2024-09-10 16:45:59 -05:00
parent 1c615a0f52
commit 4f3c7a37c8
2 changed files with 50 additions and 19 deletions

View File

@ -40,14 +40,9 @@ def load_audio( path ):
def process(
input_speaker,
yaml,
text=False,
audio_backend="encodec",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
device="cuda",
dtype="float16",
amp=False,
@ -55,7 +50,7 @@ def process(
verbose=False,
):
cfg.set_audio_backend(audio_backend)
audio_extension = cfg.audio_backend_extension
artifact_extension = cfg.audio_backend_extension
cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False
@ -74,18 +69,38 @@ def process(
for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
extension = filename.split(".")[-1]
# treat embeddings as features, if provided quantized audio
if extension in audio_extension:
if text:
if extension not in artifact_extension:
raise Exception("!")
artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en"
if "phonemes" in artifact["metadata"]:
phn = artifact["metadata"]["phonemes"]
elif "text" in artifact["metadata"]:
txt = artifact["metadata"]["text"]
phn = phonemize( txt, language=lang )
features[filename] = tts.audio_embedding( qnt )
# try and extract features from the raw audio itself
phn = phn.replace("(en)", "")
if lang != "en":
phn = phn.replace(f"({metadata['language']})", "")
features[filename] = tts.text_embedding( phn )
else:
# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
wav, sr = load_audio( f'./{input_speaker}/{filename}' )
features[filename] = mfcc(wav.to(device))[0].t()
# treat embeddings as features, if provided quantized audio
if extension in artifact_extension:
artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
features[filename] = tts.audio_embedding( qnt )
# try and extract features from the raw audio itself
else:
# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
wav, sr = load_audio( f'./{input_speaker}/{filename}' )
features[filename] = mfcc(wav.to(device))[0].t()
# calculate pairs, flattened because it makes tqdm nicer
for filename_a, embedding_a in features.items():
@ -144,21 +159,21 @@ def main():
parser.add_argument("--input-speaker", type=Path)
parser.add_argument("--yaml", type=Path)
parser.add_argument("--text", action="store_true")
parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--raise-exceptions", action="store_true")
args = parser.parse_args()
process(
input_speaker=args.input_speaker,
yaml=args.yaml,
text=args.text,
audio_backend=args.audio_backend,
raise_exceptions=args.raise_exceptions,
device=args.device,
dtype=args.dtype,
amp=args.amp,

View File

@ -123,6 +123,22 @@ class TTS():
return res
@torch.inference_mode()
def text_embedding( self, input, prom=False ):
model = None
for name, engine in self.engines.items():
model = engine.module
break
if isinstance( input, str ):
input = cfg.tokenizer.encode(input)
if isinstance( input, list ):
input = torch.tensor( input, dtype=torch.uint8, device=self.device )
return model.text_emb( input )
@torch.inference_mode()
def audio_embedding( self, input, prom=False ):
model = None