actually do the Linear replacement with TE's Linear
This commit is contained in:
parent
9d97eb5104
commit
4d75ee066c
|
@ -44,7 +44,7 @@ def load_engines(training=True):
|
|||
if inferencing:
|
||||
model._cfg.training = False
|
||||
|
||||
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
|
||||
if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
|
||||
model.model = ml.replace_linear( model.model )
|
||||
|
||||
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
|
|
|
@ -372,6 +372,7 @@ def example_usage():
|
|||
engine = Engine(model=model, optimizer=optimizer)
|
||||
|
||||
# copy embeddings if requested
|
||||
"""
|
||||
if cfg.models._embeddings is not None:
|
||||
embeddings_path = cfg.relpath / cfg.models._embeddings
|
||||
|
||||
|
@ -394,9 +395,10 @@ def example_usage():
|
|||
continue
|
||||
param.requires_grad_(False)
|
||||
engine._frozen_params.add(param)
|
||||
"""
|
||||
|
||||
# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
|
||||
model.model = ml.replace_linear( model.model )
|
||||
if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
|
||||
model.model = ml.replace_linear( model.model )
|
||||
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
|
|
|
@ -70,7 +70,7 @@ def autocasts(input, from_dtype, to_dtype):
|
|||
# handles temporarily upcasting 'index tensors' so torch will stop bitching
|
||||
def autocast_forward( func ):
|
||||
def wrapper( self, input, *args, **kwargs ):
|
||||
with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
|
||||
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
|
||||
return func( self, k, *args, **kwargs )
|
||||
return wrapper
|
||||
Embedding.forward = autocast_forward(Embedding.forward)
|
||||
|
|
Loading…
Reference in New Issue
Block a user