fixed tokenizer path deduction, added an option to default to the naive tokenizer (for old models, like ar+nar-retnet-8)

mrq 2024-06-18 22:11:14 -05:00
parent 8a986eb480
commit 62a53eed64


@@ -18,9 +18,6 @@ from pathlib import Path
 
 from .utils.distributed import world_size
 
-# Yuck
-from transformers import PreTrainedTokenizerFast
-
 @dataclass()
 class BaseConfig:
 	yaml_path: str | None = None
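
The only substantive change in this first hunk is dropping the module-level transformers import (the one flagged "# Yuck"); the second hunk below re-imports PreTrainedTokenizerFast lazily inside the try block, so transformers is only loaded when a real tokenizer file is actually requested. A minimal sketch of that deferred-import pattern, with load_pretrained_tokenizer as a hypothetical helper name used only for illustration:

def load_pretrained_tokenizer(tokenizer_file: str):
	# Importing inside the function keeps transformers optional:
	# configs that use the naive tokenizer never trigger this import.
	from transformers import PreTrainedTokenizerFast
	return PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)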
@@ -805,14 +802,20 @@ class Config(BaseConfig):
 		self.load_hdf5()
 
 		# load tokenizer
-		try:
-			cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
-			cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
-		except Exception as e:
-			cfg.tokenizer = NaiveTokenizer()
-			print("Error while parsing tokenizer:", e)
-			pass
+		if cfg.tokenizer == "naive":
+			cfg.tokenizer = NaiveTokenizer()
+		else:
+			try:
+				from transformers import PreTrainedTokenizerFast
+				tokenizer_path = cfg.rel_path / cfg.tokenizer
+				if not tokenizer_path.exists():
+					tokenizer_path = Path("./data/") / cfg.tokenizer
+				cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+			except Exception as e:
+				cfg.tokenizer = NaiveTokenizer()
+				print("Error while parsing tokenizer:", e)
+				pass
 
 		# Preserves the old behavior
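
Taken together, the new logic has three outcomes: a config whose tokenizer is the literal string "naive" gets the repository's NaiveTokenizer directly (what old checkpoints such as ar+nar-retnet-8 expect); otherwise the tokenizer file is looked up next to the model first and under ./data/ second; and any failure still falls back to the naive tokenizer. A minimal standalone sketch of that resolution order, where resolve_tokenizer is a hypothetical helper name and the NaiveTokenizer stub stands in for the repository's class:

from pathlib import Path

class NaiveTokenizer:
	"""Stand-in for the repository's NaiveTokenizer."""
	...

def resolve_tokenizer(tokenizer: str, rel_path: Path, data_dir: Path = Path("./data/")):
	# "naive" explicitly requests the fallback tokenizer (old checkpoints).
	if tokenizer == "naive":
		return NaiveTokenizer()
	try:
		# Deferred import: transformers is only needed for real tokenizer files.
		from transformers import PreTrainedTokenizerFast
		# Prefer the tokenizer stored alongside the model config...
		tokenizer_path = rel_path / tokenizer
		# ...and fall back to the shared ./data/ copy if it is missing.
		if not tokenizer_path.exists():
			tokenizer_path = data_dir / tokenizer
		return PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
	except Exception as e:
		# Any failure (missing transformers, unreadable file) degrades gracefully.
		print("Error while parsing tokenizer:", e)
		return NaiveTokenizer()

For example, resolve_tokenizer("tokenizer.json", Path("./training/example")) would try ./training/example/tokenizer.json before ./data/tokenizer.json, while passing "naive" skips transformers entirely (the directory name here is made up for illustration).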