Fixed tokenizer path deduction; added an option to default to the naive tokenizer (for old models, such as ar+nar-retnet-8)
This commit is contained in:
parent
8a986eb480
commit
62a53eed64
|
@ -18,9 +18,6 @@ from pathlib import Path
|
||||||
|
|
||||||
from .utils.distributed import world_size
|
from .utils.distributed import world_size
|
||||||
|
|
||||||
# Yuck
|
|
||||||
from transformers import PreTrainedTokenizerFast
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class BaseConfig:
|
class BaseConfig:
|
||||||
yaml_path: str | None = None
|
yaml_path: str | None = None
|
||||||
|
@ -805,10 +802,16 @@ class Config(BaseConfig):
|
||||||
self.load_hdf5()
|
self.load_hdf5()
|
||||||
|
|
||||||
# load tokenizer
|
# load tokenizer
|
||||||
|
if cfg.tokenizer == "naive":
|
||||||
|
cfg.tokenizer = NaiveTokenizer()
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
from transformers import PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizerFast
|
||||||
cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
|
|
||||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
|
tokenizer_path = cfg.rel_path / cfg.tokenizer
|
||||||
|
if not tokenizer_path.exists():
|
||||||
|
tokenizer_path = Path("./data/") / cfg.tokenizer
|
||||||
|
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cfg.tokenizer = NaiveTokenizer()
|
cfg.tokenizer = NaiveTokenizer()
|
||||||
print("Error while parsing tokenizer:", e)
|
print("Error while parsing tokenizer:", e)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user