renamed cfg.bitsandbytes to cfg.optimizations (and having it serve as cfg.optimizations.bitsandbytes)
This commit is contained in:
parent
b5d1456a09
commit
a7b43b98b5
|
@ -202,7 +202,7 @@ class Model:
|
||||||
@property
|
@property
|
||||||
# required for fp8 as the lengths needs to be divisible by 8
|
# required for fp8 as the lengths needs to be divisible by 8
|
||||||
def input_alignment(self):
|
def input_alignment(self):
|
||||||
return 8 if cfg.fp8.enabled else 0
|
return 8 if cfg.optimizations.fp8 else 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def full_name(self):
|
def full_name(self):
|
||||||
|
@ -220,7 +220,7 @@ class Model:
|
||||||
else:
|
else:
|
||||||
name.append(self.arch_type.replace("/", "-"))
|
name.append(self.arch_type.replace("/", "-"))
|
||||||
|
|
||||||
if cfg.bitsandbytes.bitnet:
|
if cfg.optimizations.bitnet:
|
||||||
name.append("bitnet")
|
name.append("bitnet")
|
||||||
|
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
|
@ -521,9 +521,10 @@ class Inference:
|
||||||
return torch.float8_e4m3fn
|
return torch.float8_e4m3fn
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
# should be renamed to optimizations
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class BitsAndBytes:
|
class Optimizations:
|
||||||
enabled: bool = False
|
bitsandbytes: bool = False
|
||||||
injects: bool = False
|
injects: bool = False
|
||||||
replace: bool = False
|
replace: bool = False
|
||||||
|
|
||||||
|
@ -531,11 +532,7 @@ class BitsAndBytes:
|
||||||
embedding: bool = True
|
embedding: bool = True
|
||||||
|
|
||||||
bitnet: bool = False
|
bitnet: bool = False
|
||||||
|
fp8: bool = False
|
||||||
@dataclass()
|
|
||||||
class FP8:
|
|
||||||
enabled: bool = False
|
|
||||||
backend: str = "te"
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Config(_Config):
|
class Config(_Config):
|
||||||
|
@ -550,11 +547,10 @@ class Config(_Config):
|
||||||
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
||||||
trainer: Trainer = field(default_factory=lambda: Trainer)
|
trainer: Trainer = field(default_factory=lambda: Trainer)
|
||||||
inference: Inference = field(default_factory=lambda: Inference)
|
inference: Inference = field(default_factory=lambda: Inference)
|
||||||
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
bitsandbytes: dict | list | None = None # deprecated
|
||||||
|
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
|
||||||
|
|
||||||
tokenizer: str = "./tokenizer.json"
|
tokenizer: str = "./tokenizer.json"
|
||||||
|
|
||||||
fp8: FP8 = field(default_factory=lambda: FP8)
|
|
||||||
|
|
||||||
sample_rate: int = 24_000
|
sample_rate: int = 24_000
|
||||||
variable_sample_rate: bool = True
|
variable_sample_rate: bool = True
|
||||||
|
@ -594,30 +590,31 @@ class Config(_Config):
|
||||||
self.dataset.use_hdf5 = False
|
self.dataset.use_hdf5 = False
|
||||||
|
|
||||||
def format( self ):
|
def format( self ):
|
||||||
#if not isinstance(self.dataset, type):
|
|
||||||
self.dataset = Dataset(**self.dataset)
|
self.dataset = Dataset(**self.dataset)
|
||||||
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
||||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||||
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
||||||
|
|
||||||
#if not isinstance(self.model, type):
|
|
||||||
if self.models is not None:
|
if self.models is not None:
|
||||||
self.model = Model(**next(iter(self.models)))
|
self.model = Model(**next(iter(self.models)))
|
||||||
else:
|
else:
|
||||||
self.model = Model(**self.model)
|
self.model = Model(**self.model)
|
||||||
|
|
||||||
#if not isinstance(self.hyperparameters, type):
|
|
||||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||||
#if not isinstance(self.evaluation, type):
|
|
||||||
self.evaluation = Evaluation(**self.evaluation)
|
self.evaluation = Evaluation(**self.evaluation)
|
||||||
#if not isinstance(self.trainer, type):
|
|
||||||
self.trainer = Trainer(**self.trainer)
|
self.trainer = Trainer(**self.trainer)
|
||||||
|
|
||||||
if not isinstance(self.trainer.deepspeed, type):
|
if not isinstance(self.trainer.deepspeed, type):
|
||||||
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||||
#if not isinstance(self.inference, type):
|
|
||||||
self.inference = Inference(**self.inference)
|
self.inference = Inference(**self.inference)
|
||||||
#if not isinstance(self.bitsandbytes, type):
|
|
||||||
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
|
if self.bitsandbytes is not None:
|
||||||
|
self.optimizations = Optimizations(**self.bitsandbytes)
|
||||||
|
else:
|
||||||
|
self.optimizations = Optimizations(**self.optimizations)
|
||||||
|
|
||||||
# Preserves the old behavior
|
# Preserves the old behavior
|
||||||
class NaiveTokenizer:
|
class NaiveTokenizer:
|
||||||
|
|
|
@ -44,7 +44,7 @@ def load_engines(training=True):
|
||||||
if inferencing:
|
if inferencing:
|
||||||
model._cfg.training = False
|
model._cfg.training = False
|
||||||
|
|
||||||
if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
|
if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
|
||||||
model.model = ml.replace_linear( model.model )
|
model.model = ml.replace_linear( model.model )
|
||||||
|
|
||||||
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||||
|
|
|
@ -368,7 +368,7 @@ def example_usage():
|
||||||
'n_layers': 12, # 32
|
'n_layers': 12, # 32
|
||||||
'n_experts': 1,
|
'n_experts': 1,
|
||||||
|
|
||||||
'l_padding': 8 if cfg.fp8.enabled else 0,
|
'l_padding': 8 if cfg.optimizations.fp8 else 0,
|
||||||
|
|
||||||
'config': cfg.model
|
'config': cfg.model
|
||||||
}
|
}
|
||||||
|
@ -397,33 +397,7 @@ def example_usage():
|
||||||
|
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
# copy embeddings if requested
|
if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
|
||||||
"""
|
|
||||||
if cfg.model._embeddings is not None:
|
|
||||||
embeddings_path = cfg.relpath / cfg.model._embeddings
|
|
||||||
|
|
||||||
if embeddings_path.exists():
|
|
||||||
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
|
|
||||||
if "module" in embeddings:
|
|
||||||
embeddings = embeddings["module"]
|
|
||||||
|
|
||||||
frozen_params = set()
|
|
||||||
for k in list(embeddings.keys()):
|
|
||||||
if re.findall(r'_emb.', k):
|
|
||||||
frozen_params.add(k)
|
|
||||||
else:
|
|
||||||
del embeddings[k]
|
|
||||||
|
|
||||||
engine.module.load_state_dict(embeddings, strict=False)
|
|
||||||
|
|
||||||
for name, param in engine.module.named_parameters():
|
|
||||||
if name not in frozen_params:
|
|
||||||
continue
|
|
||||||
param.requires_grad_(False)
|
|
||||||
engine._frozen_params.add(param)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
|
|
||||||
model.model = ml.replace_linear( model.model )
|
model.model = ml.replace_linear( model.model )
|
||||||
|
|
||||||
torch.save( {
|
torch.save( {
|
||||||
|
|
|
@ -8,20 +8,20 @@ Embedding = torch.nn.Embedding
|
||||||
Linear = torch.nn.Linear
|
Linear = torch.nn.Linear
|
||||||
|
|
||||||
# https://github.com/kyegomez/BitNet
|
# https://github.com/kyegomez/BitNet
|
||||||
if cfg.bitsandbytes.bitnet:
|
if cfg.optimizations.bitnet:
|
||||||
from bitnet import BitLinear
|
from bitnet import BitLinear
|
||||||
|
|
||||||
if cfg.bitsandbytes.enabled:
|
if cfg.optimizations.bitsandbytes:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
if cfg.bitsandbytes.linear:
|
if cfg.optimizations.linear:
|
||||||
|
|
||||||
if cfg.bitsandbytes.bitnet:
|
if cfg.optimizations.bitnet:
|
||||||
Linear = BitLinear
|
Linear = BitLinear
|
||||||
else:
|
else:
|
||||||
Linear = bnb.nn.Linear8bitLt
|
Linear = bnb.nn.Linear8bitLt
|
||||||
|
|
||||||
if cfg.bitsandbytes.embedding:
|
if cfg.optimizations.embedding:
|
||||||
Embedding = bnb.nn.modules.Embedding
|
Embedding = bnb.nn.modules.Embedding
|
||||||
"""
|
"""
|
||||||
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
|
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
|
||||||
|
@ -36,7 +36,7 @@ if cfg.bitsandbytes.enabled:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if cfg.bitsandbytes.enabled:
|
if cfg.optimizations.bitsandbytes:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
Adam = bnb.optim.Adam8bit
|
Adam = bnb.optim.Adam8bit
|
||||||
|
@ -77,7 +77,7 @@ def autocast_forward( func ):
|
||||||
return wrapper
|
return wrapper
|
||||||
Embedding.forward = autocast_forward(Embedding.forward)
|
Embedding.forward = autocast_forward(Embedding.forward)
|
||||||
|
|
||||||
if cfg.fp8.enabled:
|
if cfg.optimizations.fp8:
|
||||||
import transformer_engine.pytorch as te
|
import transformer_engine.pytorch as te
|
||||||
|
|
||||||
Linear = te.Linear
|
Linear = te.Linear
|
||||||
|
@ -90,7 +90,7 @@ else:
|
||||||
def autocast():
|
def autocast():
|
||||||
yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
|
yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
|
||||||
|
|
||||||
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
|
||||||
torch.nn.Linear = Linear
|
torch.nn.Linear = Linear
|
||||||
torch.nn.Embedding = Embedding
|
torch.nn.Embedding = Embedding
|
||||||
|
|
||||||
|
@ -98,16 +98,17 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
||||||
torch.optim.AdamW = AdamW
|
torch.optim.AdamW = AdamW
|
||||||
torch.optim.SGD = SGD
|
torch.optim.SGD = SGD
|
||||||
|
|
||||||
|
|
||||||
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||||
def replace_linear( model ):
|
def replace_linear( model ):
|
||||||
bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
|
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
|
||||||
|
klass = Linear
|
||||||
|
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
|
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
|
||||||
for *parent, k in linears:
|
for *parent, k in linears:
|
||||||
name = '.'.join(parent)
|
name = '.'.join(parent)
|
||||||
|
|
||||||
|
|
||||||
# copy parameters
|
# copy parameters
|
||||||
m = getattr( model.get_submodule(name), k )
|
m = getattr( model.get_submodule(name), k )
|
||||||
|
|
||||||
|
@ -120,7 +121,7 @@ def replace_linear( model ):
|
||||||
# overwrite
|
# overwrite
|
||||||
setattr(
|
setattr(
|
||||||
model.get_submodule(name), k,
|
model.get_submodule(name), k,
|
||||||
Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
|
klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue
Block a user