import copy
import diskcache
import h5py
import json
import os
import subprocess
import sys
import time
import argparse
import yaml
import random
import logging

import torch
import numpy as np

from dataclasses import asdict, dataclass, field

from functools import cached_property
from pathlib import Path

from .utils.distributed import world_size


def set_seed(seed=None):
	if not seed:
		seed = int(time.time()) # numpy/torch require an integer seed

	random.seed(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)

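# Example usage (values are arbitrary):
#   set_seed(42)   # fixed seed for reproducible runs
#   set_seed()     # fall back to seeding from the current time
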
@dataclass()
class BaseConfig:
	yaml_path: str | None = None # path passed in through --yaml

	@property
	def cfg_path(self):
		# directory containing the supplied YAML, or None when no YAML was given
		return Path(self.yaml_path).parent if self.yaml_path is not None else None

	@property
	def rel_path(self):
		# everything below (cache, data, checkpoints, logs) resolves relative to the YAML's directory;
		# fall back to the current working directory when running without a YAML
		return Path(self.cfg_path) if self.cfg_path is not None else Path(".")

	@property
	def cache_dir(self):
		return self.rel_path / ".cache"

	@property
	def data_dir(self):
		return self.rel_path / "data"

	@property
	def metadata_dir(self):
		return self.rel_path / "metadata"

	@property
	def ckpt_dir(self):
		return self.rel_path / "ckpt"

	@property
	def log_dir(self):
		return self.rel_path / "logs" / str(self.start_time)

	@cached_property
	def start_time(self):
		return int(time.time())

	@cached_property
	def git_commit(self):
		try:
			cmd = "git rev-parse HEAD"
			return subprocess.check_output(cmd.split()).decode("utf8").strip()
		except Exception:
			return ""

	@cached_property
	def git_status(self):
		try:
			cmd = "git status"
			return subprocess.check_output(cmd.split()).decode("utf8").strip()
		except Exception:
			return ""

	def dumps(self):
		data = {k: getattr(self, k) for k in dir(self) if not k.startswith("__")}
		data = {k: v for k, v in data.items() if not callable(v)}
		return json.dumps(data, indent=2, default=str)

	def dump(self, path=None):
		if path is None:
			path = self.log_dir / "cfg.json"
		path.parent.mkdir(parents=True, exist_ok=True)
		with open(path, "w") as f:
			f.write(self.dumps())

	@classmethod
	def from_yaml( cls, yaml_path ):
		with open(yaml_path, "r", encoding="utf-8") as f:
			state = yaml.safe_load(f) or {}
		state.setdefault("yaml_path", yaml_path)
		return cls(**state)

	@classmethod
	def from_cli(cls, args=sys.argv):
		# legacy support for yaml=`` format
		for i, arg in enumerate(args):
			if arg.startswith("yaml"):
				args[i] = f'--{arg}'

		parser = argparse.ArgumentParser(allow_abbrev=False)
		parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
		args, unknown = parser.parse_known_args(args=args)

		if args.yaml:
			return cls.from_yaml( args.yaml )

		return cls()

	def __repr__(self):
		return str(self)

	def __str__(self):
		return self.dumps()


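# Loading a configuration (the paths below are hypothetical examples):
#   cfg = Config.from_yaml(Path("./training/config.yaml"))
# or through the CLI / environment, which is what happens at the bottom of this module:
#   python -m <your_module> --yaml="./training/config.yaml"
#   VALLE_YAML="./training/config.yaml" python -m <your_module>
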
@dataclass()
class Dataset:
	training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
	validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset

	hdf5_name: str = "data.h5" # file name to load the HDF5 dataset from
	use_hdf5: bool = False # whether to load from an HDF5 dataset
	hdf5_flag: str = "a" # flag to open the HDF5 file with, automatically adjusted anyways

	validate: bool = True # validate each utterance on whether it can be included based on duration range caps
	workers: int = 8 # number of dataloader workers to spawn
	cache: bool = True # use diskcache to cache the dataset


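# Illustrative YAML fragment for this block (field names come from the dataclass above;
# the paths are placeholders):
#   dataset:
#     training: ["./data/train"]
#     validation: ["./data/val"]
#     use_hdf5: True
#     workers: 4
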
# I really need to clean this up
@dataclass()
class Model:
	name: str = "classifier"

	tokens: int = 0 # number of token types
	len: int = 1 # how long a sequence can be
	dim: int = 512
	resnet: int = 18

	width: int = 300
	height: int = 80

	version: int = 1
	training: bool = True
	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training

	arch_type: str = "" # backbone architecture, if any (e.g. "llama", "retnet"); consulted by lora_policy below
	loss_factors: dict = field(default_factory=lambda: {}) # per-loss scaling factors, looked up through loss_factor() below

	@property
	def full_name(self):
		return self.name

	def get(self, name=None):
		return [ self ] if not name or self.name == name else []

	def loss_factor(self, k):
		return self.loss_factors[k] if k in self.loss_factors else 1.0

	# required for fp8 as the lengths need to be divisible by 8
	@property
	def input_alignment(self):
		return 8 if cfg.optimizations.fp8 else 0

	@property
	def activation_checkpointing(self):
		return cfg.trainer.activation_checkpointing

	@property
	def gradient_checkpointing(self):
		return cfg.trainer.gradient_checkpointing

	@property
	def lora_policy(self):
		include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
		exclude = []

		if self.arch_type == "llama":
			include = ["self_attn", "mlp"] # target only the attention + mlp
			exclude = ["self_attn.k_proj"] # common literature says to ignore it
		if self.arch_type == "retnet":
			include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
			exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet

		return dict(include=include, exclude=exclude)


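# For example, with arch_type == "llama" the policy above evaluates to
#   {"include": ["self_attn", "mlp"], "exclude": ["self_attn.k_proj"]}
# while the default (no arch_type set) falls back to {"include": ["model"], "exclude": []}.
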
# should be renamed to Adapters
@dataclass()
class LoRA:
	name: str = "lora" # vanity name
	# to-do: find sane default values
	rank: int = 128 # rank for the LoRA
	alpha: int = 128 # alpha (scaling factor) for the LoRA
	training: bool = True #
	embeddings: bool = False # train the embeddings too
	parametrize: bool = False # whether to use the parametrized pathway for LoRAs or not
	rvq_levels: list[int] = field(default_factory=lambda: []) # determines which RVQ levels to activate the LoRA for

	@property
	def full_name(self):
		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
		return "-".join(name)

	# actually not needed anymore
	def active_level( self, level ):
		if not self.rvq_levels:
			return True
		return level in self.rvq_levels


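# e.g. LoRA(rank=128, alpha=128).full_name evaluates to "lora-r128-a128".
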
@dataclass()
class Hyperparameters:
	batch_size: int = 8 # number of samples per training batch
	gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
	gradient_clipping: int | float = 10 # maximum norm a gradient can have before being clipped

	optimizer: str = "Adamw" # optimizer to use, should be "Prodigyopt" now
	optimizer_params: dict = field(default_factory=lambda: {}) # to pass through to the deepspeed config

	learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
	warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates; I think this is just passed to deepspeed

	scheduler: str = "" # scheduler to use; I currently don't ever use one, so this doesn't really matter
	scheduler_type: str = "" # deprecated
	scheduler_params: dict = field(default_factory=lambda: {}) # to pass through to the deepspeed config

	autotune: bool = False # whether to do deepspeed's autotuning
	autotune_params: dict = field(default_factory=lambda: {}) # to pass through to the deepspeed config

	torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed-supplied
	torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied


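# Note: the effective (global) batch size works out to
#   batch_size * gradient_accumulation_steps * world_size()
# so the defaults above give 8 * 32 = 256 samples per optimizer step on a single GPU.
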
@dataclass()
class Evaluation:
	batch_size: int = 64 # number of samples per batch during eval / val
	frequency: int = 250 # do eval / val every X iterations
	size: int = 64 # number of samples to generate during eval / val

	steps: int = 500
	temperature: float = 1.0 # AR temp for inferencing

	load_disabled_engines: bool = True # see the other load_disabled_engines


@dataclass()
class DeepSpeed:
	zero_optimization_level: int = 0 # doesn't seem to work
	use_compression_training: bool = False # cope
	compression_bits: int = 8 # cope
	inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead

	amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)

	config: dict = field(default_factory=lambda: {}) # extra keys merged into the generated deepspeed config

	@cached_property
	def ds_cfg(self):
		optimizer_params = cfg.hyperparameters.optimizer_params

		if 'lr' not in optimizer_params:
			optimizer_params["lr"] = cfg.hyperparameters.learning_rate

		scheduler_params = cfg.hyperparameters.scheduler_params
		if 'warmup_num_steps' not in scheduler_params:
			scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps

		if 'total_num_steps' not in scheduler_params:
			scheduler_params['total_num_steps'] = cfg.trainer.iterations

		autotune_params = cfg.hyperparameters.autotune_params

		if "enabled" not in autotune_params:
			autotune_params['enabled'] = True

		if "results_dir" not in autotune_params:
			autotune_params['results_dir'] = str( cfg.rel_path / "autotune" / "results" )

		if "exps_dir" not in autotune_params:
			autotune_params['exps_dir'] = str( cfg.rel_path / "autotune" / "exps_" )

		# DeepSpeed fp16 is incompatible with its AMP
		if cfg.trainer.weight_dtype.lower() == "float16":
			self.amp = False

		# disable local AMP
		if self.amp:
			cfg.trainer.amp = False

		ds_cfg = {
			"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
			"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
			"optimizer": {
				"type": cfg.hyperparameters.optimizer,
				"params": optimizer_params,
			} if not cfg.hyperparameters.torch_optimizer else None,
			"scheduler": {
				"type": cfg.hyperparameters.scheduler,
				"params": scheduler_params,
			} if not cfg.hyperparameters.torch_scheduler else None,
			"gradient_clipping": cfg.hyperparameters.gradient_clipping,
			"fp16": {
				"enabled": cfg.trainer.weight_dtype.lower() == "float16",
				"auto_cast": True, # ???
				"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
			},
			"bf16": {
				"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",
			},
			"amp": {
				"enabled": self.amp,
			},
			"autotuning": autotune_params if cfg.hyperparameters.autotune else None,
			"compression_training": {
				"weight_quantization": {
					"shared_parameters": {
						"enabled": True,
						"quantizer_kernel": True,
						"schedule_offset": 0,
						"quantize_groups": 64,
						"quantize_verbose": True,
						"quantization_type": "symmetric",
						"rounding": "nearest",
						"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
						"fp16_mixed_quantize": {
							"enabled": False,
							"quantize_change_ratio": 1
						}
					},
					"different_groups": {
						"wq1": {
							"params": {
								"start_bits": self.compression_bits,
								"target_bits": self.compression_bits,
								"quantization_period": 0
							},
							"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
						}
					}
				},
				"activation_quantization": {
					"shared_parameters": {
						"enabled": True,
						"quantizer_kernel": True,
						"schedule_offset": 0,
						"quantize_groups": 64,
						"quantize_verbose": True,
						"quantization_type": "symmetric",
						"rounding": "nearest",
						"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
						"fp16_mixed_quantize": {
							"enabled": False,
							"quantize_change_ratio": 1
						}
					},
					"different_groups": {
						"aq1": {
							"params": {
								"bits": self.compression_bits,
							},
							"modules": [ "self_attn", "mlp" ] # for LLaMA, need to find for other arches
						}
					}
				},
			} if self.use_compression_training else None,
			"zero_optimization": {
				"stage": self.zero_optimization_level,
				"contiguous_gradients": True,
				"overlap_comm": True,
				"reduce_scatter": True,
				"reduce_bucket_size": 5e8,
				"allgather_bucket_size": 5e8,
				"sub_group_size": 5e8,
				"round_robin_gradients": True,
				"offload_optimizer": {
					"device": "cpu",
					"pin_memory": True
				},
				"offload_param": {
					"device": "cpu",
					"pin_memory": True
				},
				"zero_quantized_weights": self.use_compression_training,
				"zero_hpz_partition_size": world_size(),
				"zero_quantized_gradients": self.use_compression_training,
			} if self.zero_optimization_level > 0 else None,
			"comms_logger": {
				"enabled": False
			}
		}

		null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
		for k in null_keys:
			del ds_cfg[k]

		if os.path.exists("./data/ds_config.json"):
			ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
		else:
			ds_cfg.update(self.config)

		return ds_cfg


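# Any DeepSpeed-specific key the dataclass above does not expose can still be set:
# a ./data/ds_config.json (if present), or otherwise the DeepSpeed.config dict from the YAML,
# is merged over the generated ds_cfg right before it is returned.
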
@dataclass()
class Trainer:
	iterations: int = 1_000_000 # maximum iterations to train

	save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
	load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name

	save_on_oom: bool = True # save if an OOM error is raised
	save_on_quit: bool = True # save when quitting training

	export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
	export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training

	save_frequency: int = 100 # save every X iterations

	keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones

	load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
	load_states: bool = True #
	strict_loading: bool = False # whether to use strict=True when loading the state dict
	load_module_only: bool = False #
	restart_step_count: bool = False # clears the training stats when loading a checkpoint
	resize_modules: bool = False # automatically resizes modules when loading

	activation_checkpointing: bool | None = None # deprecated; should technically apply only to activations and not the entire gradients, but HF only has gradient checkpointing
	gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training

	aggressive_optimizations: bool = False # deprecated
	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
	gc_mode: str | None = None # deprecated, but marks when to do GC
	load_disabled_engines: bool = False # deprecated, but signals to also load engines not used for training (for example, for evaluation/validation)

	weight_dtype: str = "float16" # dtype to have the model under

	amp: bool = False # automatic mixed precision
	ddp: bool = False # torch's internal DDP, automatically set if the local backend is used and multiple GPUs are requested
	#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)

	load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
	no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when the logger broke because of BitNet

	backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings

	@cached_property
	def dtype(self):
		if self.weight_dtype == "float16":
			return torch.float16
		if self.weight_dtype == "bfloat16":
			return torch.bfloat16
		if self.weight_dtype == "float8_e5m2":
			return torch.float8_e5m2
		if self.weight_dtype == "float8_e4m3fn":
			return torch.float8_e4m3fn
		return torch.float32

	@cached_property
	def scale_loss(self):
		# currently cannot feasibly apply loss scaling with the DeepSpeed backend (it can handle it itself anyways)
		return self.dtype == torch.float16


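# Trainer.dtype is derived purely from weight_dtype, e.g.:
#   Trainer(weight_dtype="bfloat16").dtype -> torch.bfloat16
#   Trainer(weight_dtype="float64").dtype  -> torch.float32 (anything unrecognized falls back to float32)
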
@dataclass()
class Inference:
	backend: str = "local" # backend to use when inferencing
	weight_dtype: str = "float32" # dtype to load the model under
	amp: bool = False # automatic mixed precision during inferencing

	normalize: bool = False # do NOT enable this unless you know exactly what you're doing

	@property
	def dtype(self):
		if self.weight_dtype == "float16":
			return torch.float16
		if self.weight_dtype == "bfloat16":
			return torch.bfloat16
		if self.weight_dtype == "int8":
			return torch.int8
		if self.weight_dtype == "float8_e5m2":
			return torch.float8_e5m2
		if self.weight_dtype == "float8_e4m3fn":
			return torch.float8_e4m3fn
		return torch.float32


@dataclass()
class Optimizations:
	injects: bool = False # overwrites default torch classes (not recommended)
	replace: bool = False # replaces modules in place with the optimized version (recommended)
	compile: bool | str = False # runs torch.compile on the model

	linear: bool = True # inject/replace linear for BnB
	embedding: bool = True # inject/replace embedding for BnB
	optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)

	bitsandbytes: bool = False # use bitsandbytes
	dadaptation: bool = False # use dadaptation optimizer
	bitnet: bool = False # use bitnet
	fp8: bool = False # use fp8

	model_offloading: dict | None = None # automatically splits the model over a list of devices
	# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
	# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
	# example: {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and layers 6-11 to device 2

	tensorrt: bool = False


@dataclass()
class Config(BaseConfig):
	device: str = "cuda" # target device
	mode: str = "training" # "training" | "inferencing"
	experimental: bool = False # Debug flag, unused now

	dataset: Dataset = field(default_factory=lambda: Dataset)
	models: dict | list | None = field(default_factory=lambda: [])
	loras: dict | list | None = field(default_factory=lambda: [])
	hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
	trainer: Trainer = field(default_factory=lambda: Trainer)
	inference: Inference = field(default_factory=lambda: Inference)
	bitsandbytes: dict | list | None = None # deprecated
	optimizations: Optimizations = field(default_factory=lambda: Optimizations)

	tokenizer: str | None = None # tokenizer class
	tokenizer_path: str = "./tokenizer.json" # tokenizer path

	weights_format: str = "pth" # "pth" | "sft"
	supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])

	@property
	def model(self):
		for model in self.models:
			if model.training:
				return model

		return self.models[0] if len(self.models) > 0 else None

	# should be renamed to adapters
	@property
	def lora(self):
		for lora in self.loras:
			if lora.training:
				return lora

		return self.loras[0] if len(self.loras) > 0 else None

	@property
	def distributed(self):
		return world_size() > 1

	@cached_property
	def diskcache(self):
		if self.yaml_path is not None and self.dataset.cache:
			return diskcache.Cache(self.cache_dir).memoize
		return lambda *args, **kwargs: (lambda f: f) # no-op stand-in for Cache.memoize when caching is disabled

	# I don't remember why this is needed
	def load_yaml( self, config_path ):
		tmp = Config.from_yaml( config_path )
		self.__dict__.update(tmp.__dict__)

	def load_hdf5( self, write=False ):
		if hasattr(self, 'hdf5'):
			self.hdf5.close()

		if self.distributed:
			self.dataset.hdf5_flag = "r"
		try:
			self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do: have an easy to set flag that determines if training or creating the dataset
		except Exception as e:
			_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
			self.dataset.use_hdf5 = False

	# to-do: prune unused keys
	def format( self, training=True ):
		if isinstance(self.dataset, type):
			self.dataset = dict()

		if isinstance(self.models, type):
			self.models = dict()

		if isinstance(self.loras, type):
			self.loras = dict()

		if isinstance(self.hyperparameters, type):
			self.hyperparameters = dict()

		if isinstance(self.evaluation, type):
			self.evaluation = dict()

		if isinstance(self.trainer, type):
			self.trainer = dict()

		if isinstance(self.inference, type):
			self.inference = dict()

		if isinstance(self.optimizations, type):
			self.optimizations = dict()

		self.dataset = Dataset(**self.dataset)
		self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]

		self.models = [ Model(**model) for model in self.models ]
		self.loras = [ LoRA(**lora) for lora in self.loras ]

		if not self.models:
			self.models = [ Model() ]

		self.hyperparameters = Hyperparameters(**self.hyperparameters)

		self.evaluation = Evaluation(**self.evaluation)

		self.trainer = Trainer(**self.trainer)

		if isinstance(self.trainer.deepspeed, type):
			self.trainer.deepspeed = dict()

		if not isinstance(self.trainer.deepspeed, DeepSpeed):
			self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)

		self.inference = Inference(**self.inference)

		if self.bitsandbytes is not None:
			self.optimizations = Optimizations(**self.bitsandbytes)
		else:
			self.optimizations = Optimizations(**self.optimizations)

		if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
			self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
			self.hyperparameters.scheduler_type = ""

		# do not combine the two
		if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
			self.hyperparameters.scheduler = ""

		if self.hyperparameters.scheduler == "":
			self.hyperparameters.torch_scheduler = True

		if self.trainer.backend == "local" and self.distributed:
			self.trainer.ddp = True

		if self.trainer.activation_checkpointing is not None:
			self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing

		if not training:
			self.dataset.use_hdf5 = False

		# load our HDF5 file if requested here
		if self.dataset.use_hdf5:
			self.load_hdf5()

		# load tokenizer
		if self.tokenizer == "naive":
			self.tokenizer = NaiveTokenizer()
		else:
			try:
				from transformers import PreTrainedTokenizerFast

				tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None
				if tokenizer_path and not tokenizer_path.exists():
					tokenizer_path = Path("./data/") / self.tokenizer_path

				if tokenizer_path and tokenizer_path.exists():
					self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
				else:
					self.tokenizer = NaiveTokenizer()
			except Exception as e:
				self.tokenizer = NaiveTokenizer()
				_logger.warning(f"Error while parsing tokenizer: {str(e)}")


# Preserves the old behavior
class NaiveTokenizer:
	def get_vocab( self ):
		"""
		if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
			return json.loads( cfg.hdf5['symmap'].asstr()[()] )
		"""
		return { " ": 0, "<s>": 1, "</s>": 2, "0": 3, "2": 4, "4": 5, "8": 6, "A": 7, "D": 8, "G": 9, "H": 10, "J": 11, "K": 12, "M": 13, "N": 14, "P": 15, "R": 16, "S": 17, "T": 18, "V": 19, "W": 20, "X": 21, "Y": 22 }

	@cached_property
	def _bos_token( self ):
		return self.get_vocab()["<s>"]

	@cached_property
	def _eos_token( self ):
		return self.get_vocab()["</s>"]

	def encode( self, s ):
		symmap = self.get_vocab()
		s = s.replace("O", "0")
		s = [ "<s>" ] + [ p if p in symmap else " " for p in s ] + [ "</s>" ]
		return [*map(symmap.get, s)]


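# e.g. NaiveTokenizer().encode("04AD") returns [1, 3, 5, 7, 8, 2]
# (<s>, "0", "4", "A", "D", </s>); characters outside the vocab map to " " (0).
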

_logger = logging.getLogger(__name__)

cfg = Config.from_cli()

# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
try:
	cfg.format()
except Exception as e:
	_logger.error(f"Error while parsing config YAML: {str(e)}")
	raise e # throw an error because I'm tired of silent errors messing things up for me

if __name__ == "__main__":
	print(cfg)