renamed cfg.bitsandbytes to cfg.optimizations (and having it serve as cfg.optimizations.bitsandbytes)

mrq 2024-05-02 20:08:59 -05:00
parent b5d1456a09
commit a7b43b98b5
4 changed files with 32 additions and 60 deletions
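For orientation: the two previously separate config sections (cfg.bitsandbytes and cfg.fp8) collapse into a single cfg.optimizations section, with the old bitsandbytes.enabled and fp8.enabled booleans becoming optimizations.bitsandbytes and optimizations.fp8. A minimal standalone sketch of the new section and the flag mapping, reconstructed from the diff below rather than copied from the repo:

from dataclasses import dataclass

@dataclass()
class Optimizations:
    bitsandbytes: bool = False  # was cfg.bitsandbytes.enabled
    injects: bool = False       # was cfg.bitsandbytes.injects
    replace: bool = False       # was cfg.bitsandbytes.replace
    linear: bool = True         # was cfg.bitsandbytes.linear
    embedding: bool = True      # was cfg.bitsandbytes.embedding
    bitnet: bool = False        # was cfg.bitsandbytes.bitnet
    fp8: bool = False           # was cfg.fp8.enabled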

View File

@@ -202,7 +202,7 @@ class Model:
     @property
     # required for fp8 as the lengths needs to be divisible by 8
     def input_alignment(self):
-        return 8 if cfg.fp8.enabled else 0
+        return 8 if cfg.optimizations.fp8 else 0
     @property
     def full_name(self):
@@ -220,7 +220,7 @@ class Model:
         else:
             name.append(self.arch_type.replace("/", "-"))
-        if cfg.bitsandbytes.bitnet:
+        if cfg.optimizations.bitnet:
             name.append("bitnet")
         if self.interleave:
@@ -521,9 +521,10 @@ class Inference:
             return torch.float8_e4m3fn
         return torch.float32
-# should be renamed to optimizations
 @dataclass()
-class BitsAndBytes:
-    enabled: bool = False
+class Optimizations:
+    bitsandbytes: bool = False
     injects: bool = False
     replace: bool = False
@@ -531,11 +532,7 @@ class BitsAndBytes:
     embedding: bool = True
     bitnet: bool = False
-@dataclass()
-class FP8:
-    enabled: bool = False
-    backend: str = "te"
+    fp8: bool = False
 @dataclass()
 class Config(_Config):
@@ -550,11 +547,10 @@ class Config(_Config):
     evaluation: Evaluation = field(default_factory=lambda: Evaluation)
     trainer: Trainer = field(default_factory=lambda: Trainer)
     inference: Inference = field(default_factory=lambda: Inference)
-    bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
+    bitsandbytes: dict | list | None = None # deprecated
+    optimizations: Optimizations = field(default_factory=lambda: Optimizations)
     tokenizer: str = "./tokenizer.json"
-    fp8: FP8 = field(default_factory=lambda: FP8)
     sample_rate: int = 24_000
     variable_sample_rate: bool = True
@@ -594,30 +590,31 @@ class Config(_Config):
             self.dataset.use_hdf5 = False
     def format( self ):
         #if not isinstance(self.dataset, type):
         self.dataset = Dataset(**self.dataset)
         self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
         self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
         self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
         #if not isinstance(self.model, type):
         if self.models is not None:
             self.model = Model(**next(iter(self.models)))
         else:
             self.model = Model(**self.model)
         #if not isinstance(self.hyperparameters, type):
         self.hyperparameters = Hyperparameters(**self.hyperparameters)
         #if not isinstance(self.evaluation, type):
         self.evaluation = Evaluation(**self.evaluation)
         #if not isinstance(self.trainer, type):
         self.trainer = Trainer(**self.trainer)
         if not isinstance(self.trainer.deepspeed, type):
             self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
         #if not isinstance(self.inference, type):
         self.inference = Inference(**self.inference)
-        #if not isinstance(self.bitsandbytes, type):
-        self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
+        if self.bitsandbytes is not None:
+            self.optimizations = Optimizations(**self.bitsandbytes)
+        else:
+            self.optimizations = Optimizations(**self.optimizations)
 # Preserves the old behavior
 class NaiveTokenizer:
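The tail of format() above is the backward-compatibility shim: a config that still carries the old bitsandbytes mapping has it re-read as the new optimizations section, otherwise the optimizations mapping is used directly. A small self-contained sketch of that decision (the helper name and sample dicts are illustrative, not from the repo; Optimizations as sketched near the top):

def coerce_optimizations(raw: dict) -> "Optimizations":
    # legacy key wins: an old-style `bitsandbytes` mapping is re-interpreted
    # as the new `optimizations` section, mirroring Config.format() above
    if raw.get("bitsandbytes") is not None:
        return Optimizations(**raw["bitsandbytes"])
    return Optimizations(**raw.get("optimizations", {}))

# e.g. both of these end up populating the same cfg.optimizations-style object
new_style = coerce_optimizations({"optimizations": {"bitsandbytes": True, "fp8": True}})
old_style = coerce_optimizations({"bitsandbytes": {"injects": True, "replace": True}})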

View File

@@ -44,7 +44,7 @@ def load_engines(training=True):
         if inferencing:
             model._cfg.training = False
-        if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+        if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
             model.model = ml.replace_linear( model.model )
     if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):

View File

@@ -368,7 +368,7 @@ def example_usage():
         'n_layers': 12, # 32
         'n_experts': 1,
-        'l_padding': 8 if cfg.fp8.enabled else 0,
+        'l_padding': 8 if cfg.optimizations.fp8 else 0,
         'config': cfg.model
     }
@@ -397,33 +397,7 @@ def example_usage():
     engine = Engine(model=model, optimizer=optimizer)
-    # copy embeddings if requested
-    """
-    if cfg.model._embeddings is not None:
-        embeddings_path = cfg.relpath / cfg.model._embeddings
-        if embeddings_path.exists():
-            embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-            if "module" in embeddings:
-                embeddings = embeddings["module"]
-            frozen_params = set()
-            for k in list(embeddings.keys()):
-                if re.findall(r'_emb.', k):
-                    frozen_params.add(k)
-                else:
-                    del embeddings[k]
-            engine.module.load_state_dict(embeddings, strict=False)
-            for name, param in engine.module.named_parameters():
-                if name not in frozen_params:
-                    continue
-                param.requires_grad_(False)
-                engine._frozen_params.add(param)
-    """
-    if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+    if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
         model.model = ml.replace_linear( model.model )
     torch.save( {

View File

@@ -8,20 +8,20 @@ Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 # https://github.com/kyegomez/BitNet
-if cfg.bitsandbytes.bitnet:
+if cfg.optimizations.bitnet:
     from bitnet import BitLinear
-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
     import bitsandbytes as bnb
-    if cfg.bitsandbytes.linear:
+    if cfg.optimizations.linear:
-        if cfg.bitsandbytes.bitnet:
+        if cfg.optimizations.bitnet:
             Linear = BitLinear
         else:
             Linear = bnb.nn.Linear8bitLt
-    if cfg.bitsandbytes.embedding:
+    if cfg.optimizations.embedding:
         Embedding = bnb.nn.modules.Embedding
         """
         Embedding.forward = lambda self, input: ( self.norm(F.embedding(
@@ -36,7 +36,7 @@ if cfg.bitsandbytes.enabled:
         """
-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
     import bitsandbytes as bnb
     Adam = bnb.optim.Adam8bit
@@ -77,7 +77,7 @@ def autocast_forward( func ):
     return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)
-if cfg.fp8.enabled:
+if cfg.optimizations.fp8:
     import transformer_engine.pytorch as te
     Linear = te.Linear
@@ -90,7 +90,7 @@ else:
     def autocast():
         yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
-if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
     torch.nn.Linear = Linear
     torch.nn.Embedding = Embedding
@@ -98,16 +98,17 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
     torch.optim.AdamW = AdamW
     torch.optim.SGD = SGD
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
-    bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+    bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
+    klass = Linear
     device = next(model.parameters()).device
     linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
     for *parent, k in linears:
         name = '.'.join(parent)
         # copy parameters
         m = getattr( model.get_submodule(name), k )
@@ -120,7 +121,7 @@ def replace_linear( model ):
         # overwrite
         setattr(
             model.get_submodule(name), k,
-            Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+            klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
         )
     return model
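For context on the last two hunks: replace_linear() rebuilds every torch.nn.Linear in a model with whatever Linear class the optimization flags selected earlier in this file (BitLinear, bnb.nn.Linear8bitLt, or te.Linear), locating each layer by its dotted module path and overwriting it in place. A rough standalone sketch of the same find-and-swap pattern with a toy model (swap_linears and the plain nn.Linear stand-in are illustrative; the real helper builds its kwargs from cfg and the klass alias shown above, so quantization and weight handling differ):

import torch

def swap_linears(model: torch.nn.Module, klass=torch.nn.Linear) -> torch.nn.Module:
    device = next(model.parameters()).device
    # locate every nn.Linear by its dotted module path, then rebuild and overwrite it
    linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    for *parent, k in linears:
        name = '.'.join(parent)
        old = getattr(model.get_submodule(name), k)
        new = klass(old.in_features, old.out_features, bias=old.bias is not None).to(device)
        setattr(model.get_submodule(name), k, new)  # note: re-initializes weights
    return model

# usage: both Linear layers in the Sequential are replaced in place
toy = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
swap_linears(toy)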