also do text similarities (dont know what use I'll have for this)

mrq 2024-09-10 16:45:59 -05:00
parent 1c615a0f52
commit 4f3c7a37c8
2 changed files with 50 additions and 19 deletions

View File

@@ -40,14 +40,9 @@ def load_audio( path ):
 def process(
     input_speaker,
     yaml,
+    text=False,
     audio_backend="encodec",
-    output_dataset="training",
-    raise_exceptions=False,
-    stride=0,
-    stride_offset=0,
-    slice="auto",
     device="cuda",
     dtype="float16",
     amp=False,
@@ -55,7 +50,7 @@ def process(
     verbose=False,
 ):
     cfg.set_audio_backend(audio_backend)
-    audio_extension = cfg.audio_backend_extension
+    artifact_extension = cfg.audio_backend_extension

     cfg.inference.weight_dtype = dtype # "bfloat16"
     cfg.inference.amp = amp # False
@@ -74,8 +69,28 @@ def process(
     for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
         extension = filename.split(".")[-1]

-        # treat embeddings as features, if provided quantized audio
-        if extension in audio_extension:
-            artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
-            qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
-            qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
+        if text:
+            if extension not in artifact_extension:
+                raise Exception("!")
+
+            artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
+            lang = artifact["metadata"]["language"] if "language" in artifact["metadata"] else "en"
+
+            if "phonemes" in artifact["metadata"]:
+                phn = artifact["metadata"]["phonemes"]
+            elif "text" in artifact["metadata"]:
+                txt = artifact["metadata"]["text"]
+                phn = phonemize( txt, language=lang )
+                phn = phn.replace("(en)", "")
+                if lang != "en":
+                    phn = phn.replace(f"({lang})", "")
+
+            features[filename] = tts.text_embedding( phn )
+        else:
+            # treat embeddings as features, if provided quantized audio
+            if extension in artifact_extension:
+                artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
+                qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
+                qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
@@ -144,21 +159,21 @@ def main():
     parser.add_argument("--input-speaker", type=Path)
     parser.add_argument("--yaml", type=Path)
+    parser.add_argument("--text", action="store_true")

     parser.add_argument("--audio-backend", type=str, default="encodec")
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--amp", action="store_true")
     parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--raise-exceptions", action="store_true")

     args = parser.parse_args()

     process(
         input_speaker=args.input_speaker,
         yaml=args.yaml,
+        text=args.text,

         audio_backend=args.audio_backend,
-        raise_exceptions=args.raise_exceptions,
         device=args.device,
         dtype=args.dtype,
         amp=args.amp,
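
The commit message leaves the intended use open. As a rough sketch (not part of the commit), the `features` dict built above, whether it holds audio embeddings or the new text embeddings, could feed a pairwise cosine-similarity matrix; this assumes each entry is a `(length, dim)` float tensor and that mean-pooling over the length axis is an acceptable summary:

import torch
import torch.nn.functional as F

def similarity_matrix( features ):
    # collapse each (length, dim) embedding into a single (dim,) vector
    keys = list( features.keys() )
    pooled = torch.stack( [ features[k].float().mean(dim=0) for k in keys ] )
    # normalizing first turns a single matmul into all pairwise cosine similarities
    pooled = F.normalize( pooled, dim=-1 )
    return keys, pooled @ pooled.t()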

View File

@@ -123,6 +123,22 @@ class TTS():
         return res

+    @torch.inference_mode()
+    def text_embedding( self, input, prom=False ):
+        model = None
+        for name, engine in self.engines.items():
+            model = engine.module
+            break
+
+        if isinstance( input, str ):
+            input = cfg.tokenizer.encode(input)
+
+        if isinstance( input, list ):
+            input = torch.tensor( input, dtype=torch.uint8, device=self.device )
+
+        return model.text_emb( input )
+
     @torch.inference_mode()
     def audio_embedding( self, input, prom=False ):
         model = None
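
For reference, `model.text_emb` appears to be the model's token embedding table, so `text_embedding` returns one row per token, roughly `(num_tokens, dim)`, rather than a single vector; any similarity use has to pool it first. A hedged usage sketch, where `tts` is an instance of the class above, `phonemize` is the same helper the dataset script imports, and mean-pooling is an assumption rather than anything the commit prescribes:

import torch.nn.functional as F

# phonemize first, as the dataset script does, since the tokenizer expects phonemes
a = tts.text_embedding( phonemize( "hello world", language="en" ) )
b = tts.text_embedding( phonemize( "hello there", language="en" ) )

# one (tokens, dim) tensor each; mean-pool, then compare
sim = F.cosine_similarity( a.float().mean(dim=0), b.float().mean(dim=0), dim=0 ).item()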