renamed cfg.bitsandbytes to cfg.optimizations (with the old enabled flag now serving as cfg.optimizations.bitsandbytes)
parent b5d1456a09
commit a7b43b98b5
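
In short: the old cfg.bitsandbytes.* and cfg.fp8.* groups collapse into a single cfg.optimizations.* group, with the old enabled flags becoming plain booleans on it. A minimal before/after sketch of a call site, with names taken from the diff below:

	# before: two separate config groups
	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or cfg.fp8.enabled:
		model.model = ml.replace_linear( model.model )

	# after: one consolidated group
	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or cfg.optimizations.fp8:
		model.model = ml.replace_linear( model.model )
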
@@ -202,7 +202,7 @@ class Model:
 	@property
 	# required for fp8 as the lengths needs to be divisible by 8
 	def input_alignment(self):
-		return 8 if cfg.fp8.enabled else 0
+		return 8 if cfg.optimizations.fp8 else 0
 
 	@property
 	def full_name(self):
@@ -220,7 +220,7 @@ class Model:
 		else:
 			name.append(self.arch_type.replace("/", "-"))
 
-		if cfg.bitsandbytes.bitnet:
+		if cfg.optimizations.bitnet:
 			name.append("bitnet")
 
 		if self.interleave:
@@ -521,9 +521,10 @@ class Inference:
 			return torch.float8_e4m3fn
 	return torch.float32
 
-# should be renamed to optimizations
 @dataclass()
-class BitsAndBytes:
-	enabled: bool = False
+class Optimizations:
+	bitsandbytes: bool = False
 	injects: bool = False
 	replace: bool = False
 
@@ -531,11 +532,7 @@ class BitsAndBytes:
 	embedding: bool = True
 
 	bitnet: bool = False
-
-@dataclass()
-class FP8:
-	enabled: bool = False
-	backend: str = "te"
+	fp8: bool = False
 
 @dataclass()
 class Config(_Config):
@@ -550,11 +547,10 @@ class Config(_Config):
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
 	inference: Inference = field(default_factory=lambda: Inference)
-	bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
+	bitsandbytes: dict | list | None = None # deprecated
+	optimizations: Optimizations = field(default_factory=lambda: Optimizations)
 
 	tokenizer: str = "./tokenizer.json"
 
-	fp8: FP8 = field(default_factory=lambda: FP8)
-
 	sample_rate: int = 24_000
 	variable_sample_rate: bool = True
@@ -594,30 +590,31 @@ class Config(_Config):
 			self.dataset.use_hdf5 = False
 
 	def format( self ):
 		#if not isinstance(self.dataset, type):
 		self.dataset = Dataset(**self.dataset)
 		self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
 		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
 		self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
 
 		#if not isinstance(self.model, type):
 		if self.models is not None:
 			self.model = Model(**next(iter(self.models)))
 		else:
 			self.model = Model(**self.model)
 
 		#if not isinstance(self.hyperparameters, type):
 		self.hyperparameters = Hyperparameters(**self.hyperparameters)
-		#if not isinstance(self.evaluation, type):
+
 		self.evaluation = Evaluation(**self.evaluation)
-		#if not isinstance(self.trainer, type):
+
 		self.trainer = Trainer(**self.trainer)
+
 		if not isinstance(self.trainer.deepspeed, type):
 			self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
-		#if not isinstance(self.inference, type):
+
 		self.inference = Inference(**self.inference)
-		#if not isinstance(self.bitsandbytes, type):
-		self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
+
+		if self.bitsandbytes is not None:
+			self.optimizations = Optimizations(**self.bitsandbytes)
+		else:
+			self.optimizations = Optimizations(**self.optimizations)
 
 	# Preserves the old behavior
 	class NaiveTokenizer:
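
The format() change above keeps old configs working: a config that still populates the deprecated bitsandbytes key is funneled into Optimizations. A standalone sketch of that fallback (the dataclass fields are reconstructed from this diff; the linear field and its default are assumptions inferred from the cfg.optimizations.linear uses below):

	from dataclasses import dataclass

	@dataclass()
	class Optimizations:
		bitsandbytes: bool = False
		injects: bool = False
		replace: bool = False
		linear: bool = True     # assumed default, mirroring embedding
		embedding: bool = True
		bitnet: bool = False
		fp8: bool = False

	# mirrors the fallback in Config.format(): prefer the deprecated dict when present
	def resolve( bitsandbytes, optimizations ):
		if bitsandbytes is not None:
			return Optimizations(**bitsandbytes)
		return Optimizations(**optimizations)

	print( resolve( {"replace": True}, {} ).replace )  # True: the old key is still honored

Note that the old dict's keys must already match the new field names; a leftover enabled key would raise a TypeError here, since Optimizations has no such field.
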
@@ -44,7 +44,7 @@ def load_engines(training=True):
 	if inferencing:
 		model._cfg.training = False
 
-	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
 		model.model = ml.replace_linear( model.model )
 
 	if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
@@ -368,7 +368,7 @@ def example_usage():
 		'n_layers': 12, # 32
 		'n_experts': 1,
 
-		'l_padding': 8 if cfg.fp8.enabled else 0,
+		'l_padding': 8 if cfg.optimizations.fp8 else 0,
 
 		'config': cfg.model
 	}
@@ -397,33 +397,7 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
-	# copy embeddings if requested
-	"""
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.relpath / cfg.model._embeddings
-
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
-	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
 		model.model = ml.replace_linear( model.model )
 
 	torch.save( {
@@ -8,20 +8,20 @@ Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 
 # https://github.com/kyegomez/BitNet
-if cfg.bitsandbytes.bitnet:
+if cfg.optimizations.bitnet:
 	from bitnet import BitLinear
 
-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
 	import bitsandbytes as bnb
 
-	if cfg.bitsandbytes.linear:
+	if cfg.optimizations.linear:
 
-		if cfg.bitsandbytes.bitnet:
+		if cfg.optimizations.bitnet:
 			Linear = BitLinear
 		else:
 			Linear = bnb.nn.Linear8bitLt
 
-	if cfg.bitsandbytes.embedding:
+	if cfg.optimizations.embedding:
 		Embedding = bnb.nn.modules.Embedding
 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
@@ -36,7 +36,7 @@ if cfg.bitsandbytes.enabled:
 		"""
 
 
-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
 	import bitsandbytes as bnb
 
 	Adam = bnb.optim.Adam8bit
@@ -77,7 +77,7 @@ def autocast_forward( func ):
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)
 
-if cfg.fp8.enabled:
+if cfg.optimizations.fp8:
 	import transformer_engine.pytorch as te
 
 	Linear = te.Linear
@@ -90,7 +90,7 @@ else:
 	def autocast():
 		yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
 
-if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
 	torch.nn.Linear = Linear
 	torch.nn.Embedding = Embedding
 
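
For context, the inject path above monkey-patches the torch namespace itself: once both cfg.optimizations.injects and cfg.optimizations.bitsandbytes are set, downstream code that instantiates torch.nn.Linear or torch.nn.Embedding transparently receives the replacement classes. A minimal sketch of the effect:

	import torch

	# with the injection above applied, these build the substituted classes,
	# e.g. bnb.nn.Linear8bitLt / bnb.nn.modules.Embedding rather than the stock modules
	layer = torch.nn.Linear(512, 512)
	emb = torch.nn.Embedding(1000, 512)
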
@@ -98,16 +98,17 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
 
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
-	bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
+	klass = Linear
 
 	device = next(model.parameters()).device
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
 	for *parent, k in linears:
 		name = '.'.join(parent)
 
 
 		# copy parameters
 		m = getattr( model.get_submodule(name), k )
@@ -120,7 +121,7 @@ def replace_linear( model ):
 		# overwrite
 		setattr(
 			model.get_submodule(name), k,
-			Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+			klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)
 
 	return model
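
As a usage sketch, this is how the call sites earlier in the commit drive the helper (the toy model is an illustration; ml is the module alias those call sites already use):

	import torch

	# any module tree containing torch.nn.Linear layers works
	model = torch.nn.Sequential(
		torch.nn.Linear(16, 32),
		torch.nn.ReLU(),
		torch.nn.Linear(32, 8),
	)

	# with cfg.optimizations.bitsandbytes and cfg.optimizations.replace (or cfg.optimizations.fp8)
	# enabled, every torch.nn.Linear is rebuilt as klass on the model's device and trainer dtype
	model = ml.replace_linear( model )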