actually do the Linear replacement with TE's Linear

mrq 2024-04-09 14:41:13 -05:00
parent 9d97eb5104
commit 4d75ee066c
3 changed files with 6 additions and 4 deletions
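
In short: the bitsandbytes path was already swapping nn.Linear modules via ml.replace_linear, but with only fp8 enabled the swap never ran, so TransformerEngine's (TE) Linear was never actually substituted in. The hunks below gate the replacement on cfg.fp8.enabled as well. For orientation, a minimal sketch of what a module-walking helper along the lines of ml.replace_linear could look like, assuming NVIDIA TransformerEngine is installed; this is an illustration, not the repo's actual implementation:

	# Hypothetical sketch of a Linear-replacement helper in the spirit
	# of ml.replace_linear(); the repo's real helper may differ.
	import torch
	import torch.nn as nn
	import transformer_engine.pytorch as te  # TE's fp8-capable Linear

	def replace_linear(model: nn.Module) -> nn.Module:
		for name, child in model.named_children():
			if isinstance(child, nn.Linear):
				replacement = te.Linear(
					child.in_features,
					child.out_features,
					bias=child.bias is not None,
				)
				# carry the already-initialized weights over to the new module
				with torch.no_grad():
					replacement.weight.copy_(child.weight)
					if child.bias is not None:
						replacement.bias.copy_(child.bias)
				setattr(model, name, replacement)
			else:
				replace_linear(child)  # recurse into non-Linear submodules
		return model

Copying the trained weights across matters here because the replacement below runs on an already-constructed model.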


@@ -44,7 +44,7 @@ def load_engines(training=True):
 		if inferencing:
 			model._cfg.training = False
-		if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+		if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
 			model.model = ml.replace_linear( model.model )
 		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
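
Swapping the module class is necessary but, with TE, not sufficient on its own: TE layers only execute in fp8 inside TE's fp8 autocast context. A hedged usage sketch; the recipe settings are placeholder assumptions, not values taken from this repo:

	import transformer_engine.pytorch as te
	from transformer_engine.common import recipe

	# placeholder recipe; the margin/format choices are illustrative
	fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

	with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
		logits = model(inputs)  # stand-ins for the engine's own forward call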


@@ -372,6 +372,7 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)

 	# copy embeddings if requested
+	"""
 	if cfg.models._embeddings is not None:
 		embeddings_path = cfg.relpath / cfg.models._embeddings
@@ -394,9 +395,10 @@ def example_usage():
 				continue
 			param.requires_grad_(False)
 			engine._frozen_params.add(param)
+	"""

-	# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
-	model.model = ml.replace_linear( model.model )
+	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+		model.model = ml.replace_linear( model.model )

 	torch.save( {
 		'module': model.state_dict()
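
One side effect worth noting: because the replacement now runs before torch.save, the saved 'module' state dict reflects the replaced modules, so a loader would have to apply the same replacement before calling load_state_dict. A sketch under that assumption; the path and constructor are placeholders:

	import torch

	model = Model()  # placeholder: however example_usage builds its model
	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
		model.model = ml.replace_linear( model.model )
	state = torch.load("./checkpoint.pth")["module"]  # illustrative path
	model.load_state_dict(state)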


@@ -70,7 +70,7 @@ def autocasts(input, from_dtype, to_dtype):
 # handles temporarily upcasting 'index tensors' so torch will stop bitching
 def autocast_forward( func ):
 	def wrapper( self, input, *args, **kwargs ):
-		with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
+		with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
			return func( self, k, *args, **kwargs )
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)
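
The widened dtype list presumably covers index tensors that arrive as float16/bfloat16 under half-precision or fp8 execution, which Embedding.forward would otherwise reject. For reference, a minimal sketch of a context manager matching the autocasts signature above; illustrative, not the repo's exact code:

	from contextlib import contextmanager
	import torch

	@contextmanager
	def autocasts(input, from_dtype, to_dtype):
		# hand the caller an upcast copy of the tensor when its dtype is
		# on the list, otherwise pass it through untouched
		if input.dtype in from_dtype:
			yield input.to(to_dtype)
		else:
			yield input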