also do text similarities (dont know what use I'll have for this)
This commit is contained in:
parent
1c615a0f52
commit
4f3c7a37c8
|
@ -40,14 +40,9 @@ def load_audio( path ):
|
||||||
def process(
|
def process(
|
||||||
input_speaker,
|
input_speaker,
|
||||||
yaml,
|
yaml,
|
||||||
|
text=False,
|
||||||
|
|
||||||
audio_backend="encodec",
|
audio_backend="encodec",
|
||||||
output_dataset="training",
|
|
||||||
raise_exceptions=False,
|
|
||||||
stride=0,
|
|
||||||
stride_offset=0,
|
|
||||||
slice="auto",
|
|
||||||
|
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype="float16",
|
dtype="float16",
|
||||||
amp=False,
|
amp=False,
|
||||||
|
@ -55,7 +50,7 @@ def process(
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
cfg.set_audio_backend(audio_backend)
|
cfg.set_audio_backend(audio_backend)
|
||||||
audio_extension = cfg.audio_backend_extension
|
artifact_extension = cfg.audio_backend_extension
|
||||||
|
|
||||||
cfg.inference.weight_dtype = dtype # "bfloat16"
|
cfg.inference.weight_dtype = dtype # "bfloat16"
|
||||||
cfg.inference.amp = amp # False
|
cfg.inference.amp = amp # False
|
||||||
|
@ -74,8 +69,28 @@ def process(
|
||||||
for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
|
for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
|
||||||
extension = filename.split(".")[-1]
|
extension = filename.split(".")[-1]
|
||||||
|
|
||||||
|
|
||||||
|
if text:
|
||||||
|
if extension not in artifact_extension:
|
||||||
|
raise Exception("!")
|
||||||
|
|
||||||
|
artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
|
||||||
|
|
||||||
|
lang = artifact["metadata"]["language"] if "language" in artifact["metadata"]["language"] else "en"
|
||||||
|
if "phonemes" in artifact["metadata"]:
|
||||||
|
phn = artifact["metadata"]["phonemes"]
|
||||||
|
elif "text" in artifact["metadata"]:
|
||||||
|
txt = artifact["metadata"]["text"]
|
||||||
|
phn = phonemize( txt, language=lang )
|
||||||
|
|
||||||
|
phn = phn.replace("(en)", "")
|
||||||
|
if lang != "en":
|
||||||
|
phn = phn.replace(f"({metadata['language']})", "")
|
||||||
|
|
||||||
|
features[filename] = tts.text_embedding( phn )
|
||||||
|
else:
|
||||||
# treat embeddings as features, if provided quantized audio
|
# treat embeddings as features, if provided quantized audio
|
||||||
if extension in audio_extension:
|
if extension in artifact_extension:
|
||||||
artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
|
artifact = np.load(f'./{input_speaker}/{filename}', allow_pickle=True)[()]
|
||||||
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
|
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16, device=device)
|
||||||
qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
|
qnt = trim( qnt, int( cfg.dataset.frames_per_second * 3 ) )
|
||||||
|
@ -144,21 +159,21 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--input-speaker", type=Path)
|
parser.add_argument("--input-speaker", type=Path)
|
||||||
parser.add_argument("--yaml", type=Path)
|
parser.add_argument("--yaml", type=Path)
|
||||||
|
parser.add_argument("--text", action="store_true")
|
||||||
|
|
||||||
parser.add_argument("--audio-backend", type=str, default="encodec")
|
parser.add_argument("--audio-backend", type=str, default="encodec")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
parser.add_argument("--raise-exceptions", action="store_true")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
process(
|
process(
|
||||||
input_speaker=args.input_speaker,
|
input_speaker=args.input_speaker,
|
||||||
yaml=args.yaml,
|
yaml=args.yaml,
|
||||||
|
text=args.text,
|
||||||
|
|
||||||
audio_backend=args.audio_backend,
|
audio_backend=args.audio_backend,
|
||||||
raise_exceptions=args.raise_exceptions,
|
|
||||||
|
|
||||||
device=args.device,
|
device=args.device,
|
||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
amp=args.amp,
|
amp=args.amp,
|
||||||
|
|
|
@ -123,6 +123,22 @@ class TTS():
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def text_embedding( self, input, prom=False ):
|
||||||
|
model = None
|
||||||
|
|
||||||
|
for name, engine in self.engines.items():
|
||||||
|
model = engine.module
|
||||||
|
break
|
||||||
|
|
||||||
|
if isinstance( input, str ):
|
||||||
|
input = cfg.tokenizer.encode(input)
|
||||||
|
|
||||||
|
if isinstance( input, list ):
|
||||||
|
input = torch.tensor( input, dtype=torch.uint8, device=self.device )
|
||||||
|
|
||||||
|
return model.text_emb( input )
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def audio_embedding( self, input, prom=False ):
|
def audio_embedding( self, input, prom=False ):
|
||||||
model = None
|
model = None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user