From 62a53eed64eb4fc1a6d9a3b9e9cb02f95878f30e Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 18 Jun 2024 22:11:14 -0500 Subject: [PATCH] fixed deducing tokenizer path, added option to default to naive tokenizer (for old models, like ar+nar-retnet-8) --- vall_e/config.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index a77bce7..016029a 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -18,9 +18,6 @@ from pathlib import Path from .utils.distributed import world_size -# Yuck -from transformers import PreTrainedTokenizerFast - @dataclass() class BaseConfig: yaml_path: str | None = None @@ -805,14 +802,20 @@ class Config(BaseConfig): self.load_hdf5() # load tokenizer - try: - from transformers import PreTrainedTokenizerFast - cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer - cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer)) - except Exception as e: + if cfg.tokenizer == "naive": cfg.tokenizer = NaiveTokenizer() - print("Error while parsing tokenizer:", e) - pass + else: + try: + from transformers import PreTrainedTokenizerFast + + tokenizer_path = cfg.rel_path / cfg.tokenizer + if not tokenizer_path.exists(): + tokenizer_path = Path("./data/") / cfg.tokenizer + cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) + except Exception as e: + cfg.tokenizer = NaiveTokenizer() + print("Error while parsing tokenizer:", e) + pass # Preserves the old behavior