fixed deducing tokenizer path, added option to default to naive tokenizer (for old models, like ar+nar-retnet-8)
This commit is contained in:
parent
8a986eb480
commit
62a53eed64
|
@ -18,9 +18,6 @@ from pathlib import Path
|
|||
|
||||
from .utils.distributed import world_size
|
||||
|
||||
# Yuck
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
@dataclass()
|
||||
class BaseConfig:
|
||||
yaml_path: str | None = None
|
||||
|
@ -805,14 +802,20 @@ class Config(BaseConfig):
|
|||
self.load_hdf5()
|
||||
|
||||
# load tokenizer
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
|
||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
|
||||
except Exception as e:
|
||||
if cfg.tokenizer == "naive":
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
print("Error while parsing tokenizer:", e)
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
tokenizer_path = cfg.rel_path / cfg.tokenizer
|
||||
if not tokenizer_path.exists():
|
||||
tokenizer_path = Path("./data/") / cfg.tokenizer
|
||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
|
||||
except Exception as e:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
print("Error while parsing tokenizer:", e)
|
||||
pass
|
||||
|
||||
|
||||
# Preserves the old behavior
|
||||
|
|
Loading…
Reference in New Issue
Block a user