renamed cfg.bitsandbytes to cfg.optimizations (with the old cfg.bitsandbytes.enabled flag now living at cfg.optimizations.bitsandbytes)

mrq 2024-05-02 20:08:59 -05:00
parent b5d1456a09
commit a7b43b98b5
4 changed files with 32 additions and 60 deletions
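
In practical terms, settings that previously lived under cfg.bitsandbytes and cfg.fp8 now sit under a single cfg.optimizations section, with the old bitsandbytes.enabled and fp8.enabled toggles becoming optimizations.bitsandbytes and optimizations.fp8. A rough sketch of that mapping, using plain dicts as stand-ins for a config file (the exact key set is inferred from the dataclasses in the diff below, not quoted from a real config):

# hypothetical before/after of a config's optimization settings under this rename
old_style = {
    "bitsandbytes": {"enabled": True, "injects": False, "replace": False, "bitnet": False},
    "fp8": {"enabled": False, "backend": "te"},
}

new_style = {
    "optimizations": {
        "bitsandbytes": old_style["bitsandbytes"]["enabled"],  # old `enabled` toggle
        "injects": old_style["bitsandbytes"]["injects"],
        "replace": old_style["bitsandbytes"]["replace"],
        "bitnet": old_style["bitsandbytes"]["bitnet"],
        "fp8": old_style["fp8"]["enabled"],                    # old fp8.enabled toggle
        # the old fp8.backend ("te") has no counterpart in the new Optimizations dataclass
    }
}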

View File

@@ -202,7 +202,7 @@ class Model:
 	@property
 	# required for fp8 as the lengths needs to be divisible by 8
 	def input_alignment(self):
-		return 8 if cfg.fp8.enabled else 0
+		return 8 if cfg.optimizations.fp8 else 0

 	@property
 	def full_name(self):
@@ -220,7 +220,7 @@ class Model:
 		else:
 			name.append(self.arch_type.replace("/", "-"))

-		if cfg.bitsandbytes.bitnet:
+		if cfg.optimizations.bitnet:
 			name.append("bitnet")

 		if self.interleave:
@@ -521,9 +521,10 @@ class Inference:
 			return torch.float8_e4m3fn
 		return torch.float32

-# should be renamed to optimizations
 @dataclass()
-class BitsAndBytes:
-	enabled: bool = False
+class Optimizations:
+	bitsandbytes: bool = False
 	injects: bool = False
 	replace: bool = False
@@ -531,11 +532,7 @@ class BitsAndBytes:
 	embedding: bool = True
 	bitnet: bool = False
-
-@dataclass()
-class FP8:
-	enabled: bool = False
-	backend: str = "te"
+	fp8: bool = False

 @dataclass()
 class Config(_Config):
@@ -550,11 +547,10 @@ class Config(_Config):
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
 	inference: Inference = field(default_factory=lambda: Inference)
-	bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
+	bitsandbytes: dict | list | None = None # deprecated
+	optimizations: Optimizations = field(default_factory=lambda: Optimizations)

 	tokenizer: str = "./tokenizer.json"
-	fp8: FP8 = field(default_factory=lambda: FP8)

 	sample_rate: int = 24_000
 	variable_sample_rate: bool = True
@@ -594,30 +590,31 @@ class Config(_Config):
 			self.dataset.use_hdf5 = False

 	def format( self ):
-		#if not isinstance(self.dataset, type):
 		self.dataset = Dataset(**self.dataset)
 		self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
 		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
 		self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]

-		#if not isinstance(self.model, type):
 		if self.models is not None:
 			self.model = Model(**next(iter(self.models)))
 		else:
 			self.model = Model(**self.model)

-		#if not isinstance(self.hyperparameters, type):
 		self.hyperparameters = Hyperparameters(**self.hyperparameters)

-		#if not isinstance(self.evaluation, type):
 		self.evaluation = Evaluation(**self.evaluation)

-		#if not isinstance(self.trainer, type):
 		self.trainer = Trainer(**self.trainer)
 		if not isinstance(self.trainer.deepspeed, type):
 			self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)

-		#if not isinstance(self.inference, type):
 		self.inference = Inference(**self.inference)

-		#if not isinstance(self.bitsandbytes, type):
-		self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
+		if self.bitsandbytes is not None:
+			self.optimizations = Optimizations(**self.bitsandbytes)
+		else:
+			self.optimizations = Optimizations(**self.optimizations)

 # Preserves the old behavior
 class NaiveTokenizer:
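
The compatibility branch at the end of format() keeps configs that still carry a bitsandbytes section loading. A self-contained sketch of the same idea (simplified field set; resolve_optimizations and the `linear` field are my own naming and inference, not code from this commit):

from dataclasses import dataclass

@dataclass
class Optimizations:
    bitsandbytes: bool = False
    injects: bool = False
    replace: bool = False
    linear: bool = True
    embedding: bool = True
    bitnet: bool = False
    fp8: bool = False

def resolve_optimizations(bitsandbytes: dict | None, optimizations: dict) -> Optimizations:
    # mirror of the branch in Config.format() above: a legacy `bitsandbytes`
    # dict wins if present, otherwise the new `optimizations` dict is used;
    # keys the new dataclass does not know (e.g. the old `enabled`) would still
    # need renaming or dropping before this call
    if bitsandbytes is not None:
        return Optimizations(**bitsandbytes)
    return Optimizations(**optimizations)

# legacy-style section (old `enabled` already translated to `bitsandbytes`)
print(resolve_optimizations({"bitsandbytes": True, "replace": True}, {}))
# new-style section
print(resolve_optimizations(None, {"fp8": True}))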

View File

@@ -44,7 +44,7 @@ def load_engines(training=True):
 		if inferencing:
 			model._cfg.training = False

-		if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+		if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
 			model.model = ml.replace_linear( model.model )

 		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):

View File

@@ -368,7 +368,7 @@ def example_usage():
 		'n_layers': 12, # 32
 		'n_experts': 1,

-		'l_padding': 8 if cfg.fp8.enabled else 0,
+		'l_padding': 8 if cfg.optimizations.fp8 else 0,

 		'config': cfg.model
 	}
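
The l_padding value above mirrors Model.input_alignment from the first file: fp8 wants sequence lengths divisible by 8, everything else gets no padding. A small illustrative helper for applying such an alignment (the function name and use of torch.nn.functional.pad are mine, not this repo's API):

import torch
import torch.nn.functional as F

def pad_to_alignment(seq: torch.Tensor, alignment: int) -> torch.Tensor:
    # right-pad a (length, dim) tensor so its length is a multiple of `alignment`;
    # alignment <= 1 (fp8 disabled) returns the input unchanged
    if alignment <= 1:
        return seq
    remainder = seq.shape[0] % alignment
    if remainder == 0:
        return seq
    # F.pad pads dimensions from the last one backwards: (dim_left, dim_right, len_left, len_right)
    return F.pad(seq, (0, 0, 0, alignment - remainder))

x = torch.randn(13, 512)
assert pad_to_alignment(x, 8).shape[0] == 16  # 13 -> 16, divisible by 8
assert pad_to_alignment(x, 0).shape[0] == 13  # alignment 0 leaves it alone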
@@ -397,33 +397,7 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)

-	# copy embeddings if requested
-	"""
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.relpath / cfg.model._embeddings
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
-	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
 		model.model = ml.replace_linear( model.model )

 	torch.save( {

View File

@@ -8,20 +8,20 @@ Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear

 # https://github.com/kyegomez/BitNet
-if cfg.bitsandbytes.bitnet:
+if cfg.optimizations.bitnet:
 	from bitnet import BitLinear

-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
 	import bitsandbytes as bnb

-	if cfg.bitsandbytes.linear:
-		if cfg.bitsandbytes.bitnet:
+	if cfg.optimizations.linear:
+		if cfg.optimizations.bitnet:
 			Linear = BitLinear
 		else:
 			Linear = bnb.nn.Linear8bitLt

-	if cfg.bitsandbytes.embedding:
+	if cfg.optimizations.embedding:
 		Embedding = bnb.nn.modules.Embedding

 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
@@ -36,7 +36,7 @@ if cfg.bitsandbytes.enabled:
 		"""

-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
 	import bitsandbytes as bnb

 	Adam = bnb.optim.Adam8bit
@@ -77,7 +77,7 @@ def autocast_forward( func ):
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)

-if cfg.fp8.enabled:
+if cfg.optimizations.fp8:
 	import transformer_engine.pytorch as te

 	Linear = te.Linear
@@ -90,7 +90,7 @@ else:
 	def autocast():
 		yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)

-if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
 	torch.nn.Linear = Linear
 	torch.nn.Embedding = Embedding
@@ -98,16 +98,17 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD

 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
-	bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
+	klass = Linear

 	device = next(model.parameters()).device
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]

 	for *parent, k in linears:
 		name = '.'.join(parent)

 		# copy parameters
 		m = getattr( model.get_submodule(name), k )
@@ -120,7 +121,7 @@ def replace_linear( model ):
 		# overwrite
 		setattr(
 			model.get_submodule(name), k,
-			Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+			klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)

 	return model
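
The new klass indirection lets replace_linear() rebuild every nn.Linear with whichever Linear alias is active at import time (BitLinear, bnb's Linear8bitLt, te.Linear, or plain torch.nn.Linear). A rough standalone sketch of that module-swapping pattern, using only plain PyTorch and my own function name (the bitsandbytes-specific kwargs and weight copying of the real routine are omitted):

import torch

def replace_linear_sketch(model: torch.nn.Module, klass=torch.nn.Linear) -> torch.nn.Module:
    # swap every nn.Linear in `model` for `klass`, preserving shapes and bias flags
    device = next(model.parameters()).device
    linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    for *parent, k in linears:
        name = '.'.join(parent)
        m = getattr(model.get_submodule(name), k)
        setattr(
            model.get_submodule(name), k,
            klass(
                in_features=m.in_features,
                out_features=m.out_features,
                bias=m.bias is not None,
            ).to(device=device, dtype=m.weight.dtype),
        )
    return model

# usage: every Linear in a toy model gets rebuilt (here with the same class, so effectively a re-init)
toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
replace_linear_sketch(toy)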