actually do the Linear replacement with TE's Linear

This commit is contained in:
mrq 2024-04-09 14:41:13 -05:00
parent 9d97eb5104
commit 4d75ee066c
3 changed files with 6 additions and 4 deletions

View File

@@ -44,7 +44,7 @@ def load_engines(training=True):
 	if inferencing:
 		model._cfg.training = False
-	if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
 		model.model = ml.replace_linear( model.model )
 	if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):

View File

@@ -372,6 +372,7 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)
 	# copy embeddings if requested
+	"""
 	if cfg.models._embeddings is not None:
 		embeddings_path = cfg.relpath / cfg.models._embeddings
@@ -394,9 +395,10 @@ def example_usage():
 				continue
 			param.requires_grad_(False)
 			engine._frozen_params.add(param)
+	"""
-	# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
 		model.model = ml.replace_linear( model.model )
 	torch.save( {
 		'module': model.state_dict()

View File

@@ -70,7 +70,7 @@ def autocasts(input, from_dtype, to_dtype):
 # handles temporarily upcasting 'index tensors' so torch will stop bitching
 def autocast_forward( func ):
 	def wrapper( self, input, *args, **kwargs ):
-		with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
+		with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
 			return func( self, k, *args, **kwargs )
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)