also do text similarities (don't know what use I'll have for this)
This commit is contained in:
parent 1c615a0f52
commit 4f3c7a37c8

@@ -40,14 +40,9 @@ def load_audio( path ):
 def process(
 	input_speaker,
 	yaml,
+	text=False,
 
 	audio_backend="encodec",
-	output_dataset="training",
 	raise_exceptions=False,
-	stride=0,
-	stride_offset=0,
-	slice="auto",
-
 	device="cuda",
 	dtype="float16",
 	amp=False,

@@ -55,7 +50,7 @@ def process(
 	verbose=False,
 ):
 	cfg.set_audio_backend(audio_backend)
+	audio_extension = cfg.audio_backend_extension
 	artifact_extension = cfg.audio_backend_extension
 
 	cfg.inference.weight_dtype = dtype # "bfloat16"
 	cfg.inference.amp = amp # False

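For context, a `weight_dtype` plus `amp` pair like this usually feeds a torch autocast region further downstream; a rough sketch of that common pattern, not this repo's actual plumbing:

	import torch

	# map the string flag to a torch dtype, then run forwards under autocast when amp is on
	dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32 }
	weight_dtype = dtype_map["float16"]

	with torch.autocast("cuda", dtype=weight_dtype, enabled=True):
		pass  # model forward passes would go here
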
@@ -74,18 +69,38 @@ def process(
 	for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
 		extension = filename.split(".")[-1]
 
-		# treat embeddings as features, if provided quantized audio
-		if extension in artifact_extension:
-			artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
-			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
-			qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
-
-			features[filename] = tts.audio_embedding( qnt )
-		# try and extract features from the raw audio itself
-		else:
-			# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
-			wav, sr = load_audio( f'./{input_speaker}/{filename}' )
-			features[filename] = mfcc(wav.to(device))[0].t()
+		# treat embeddings as features, if provided quantized audio
+		if extension in audio_extension:
+			if text:
+				if extension not in artifact_extension:
+					raise Exception("!")
+
+			artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
+			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
+			qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
+
+			if text:
+				lang = artifact["metadata"]["language"] if "language" in artifact["metadata"] else "en"
+				if "phonemes" in artifact["metadata"]:
+					phn = artifact["metadata"]["phonemes"]
+				elif "text" in artifact["metadata"]:
+					txt = artifact["metadata"]["text"]
+					phn = phonemize( txt, language=lang )
+
+				# strip the language-switch tags phonemize() can emit, e.g. "(en)"
+				phn = phn.replace("(en)", "")
+				if lang != "en":
+					phn = phn.replace(f"({lang})", "")
+
+				features[filename] = tts.text_embedding( phn )
+			else:
+				features[filename] = tts.audio_embedding( qnt )
+		# try and extract features from the raw audio itself
+		else:
+			# qnt = tts.encode_audio(f'./{input_speaker}/{filename}', trim_length=3.0).to(device)
+			wav, sr = load_audio( f'./{input_speaker}/{filename}' )
+			features[filename] = mfcc(wav.to(device))[0].t()
 
 	# calculate pairs, flattened because it makes tqdm nicer
 	for filename_a, embedding_a in features.items():

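The pair loop is cut off by the hunk; as a sketch of how the flattened comparison the comment describes might proceed (the mean-pooling, the itertools flattening, and the cosine metric are all assumptions here, not this commit's code):

	from itertools import combinations
	import torch.nn.functional as F
	from tqdm import tqdm

	def pairwise_similarities( features ):
		# mean-pool along the time axis so variable-length embeddings compare as fixed-size vectors
		pooled = { k: v.float().mean(dim=0) for k, v in features.items() }
		# flatten the unique (a, b) pairs up front so a single tqdm bar covers every comparison
		pairs = list( combinations( pooled.keys(), 2 ) )
		return {
			(a, b): F.cosine_similarity( pooled[a], pooled[b], dim=0 ).item()
			for a, b in tqdm( pairs, desc="Comparing..." )
		}
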
@@ -144,21 +159,21 @@ def main():
 
 	parser.add_argument("--input-speaker", type=Path)
 	parser.add_argument("--yaml", type=Path)
+	parser.add_argument("--text", action="store_true")
 
 	parser.add_argument("--audio-backend", type=str, default="encodec")
 	parser.add_argument("--dtype", type=str, default="bfloat16")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--device", type=str, default="cuda")
 	parser.add_argument("--raise-exceptions", action="store_true")
 
 	args = parser.parse_args()
 
 	process(
 		input_speaker=args.input_speaker,
 		yaml=args.yaml,
+		text=args.text,
 
 		audio_backend=args.audio_backend,
 		raise_exceptions=args.raise_exceptions,
 
 		device=args.device,
 		dtype=args.dtype,
 		amp=args.amp,

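For reference, a direct call mirroring the CLI above might look like the following sketch; the paths are hypothetical placeholders, and the argument list assumes the process() signature shown in the first hunk:

	from pathlib import Path

	process(
		input_speaker=Path("./voices/speaker_a"),  # hypothetical folder of encoded clips
		yaml=Path("./config.yaml"),                # hypothetical model config
		text=True,                                 # the new flag: compare text embeddings
		audio_backend="encodec",
		raise_exceptions=True,
		device="cuda",
		dtype="bfloat16",
		amp=False,
	)
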
@@ -123,6 +123,22 @@ class TTS():
 
 		return res
 
+	@torch.inference_mode()
+	def text_embedding( self, input, prom=False ):
+		model = None
+
+		for name, engine in self.engines.items():
+			model = engine.module
+			break
+
+		if isinstance( input, str ):
+			input = cfg.tokenizer.encode(input)
+
+		if isinstance( input, list ):
+			input = torch.tensor( input, dtype=torch.uint8, device=self.device )
+
+		return model.text_emb( input )
+
 	@torch.inference_mode()
 	def audio_embedding( self, input, prom=False ):
 		model = None

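A quick usage sketch of the two embedding helpers, assuming an already-constructed `tts` instance and a loaded `qnt` code tensor as set up earlier in the script (the phoneme string here is an arbitrary example):

	# str input: tokenized via cfg.tokenizer.encode, then tensorized, then embedded
	emb_text = tts.text_embedding( "hˈɛloʊ wˈɜːld" )

	# list input: tensorized directly (dtype=torch.uint8, so this assumes a vocab under 256)
	emb_ids = tts.text_embedding( [1, 2, 3] )

	# audio side: expects quantized codes, e.g. the trimmed qnt tensor from process() above
	emb_audio = tts.audio_embedding( qnt )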