actually do the Linear replacement with TE's Linear

parent 9d97eb5104
commit 4d75ee066c
@@ -44,7 +44,7 @@ def load_engines(training=True):
 		if inferencing:
 			model._cfg.training = False
 
-		if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+		if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
 			model.model = ml.replace_linear( model.model )
 
 		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
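For context, `ml.replace_linear` is the helper invoked above; its definition is not part of this diff. A minimal sketch of the kind of swap the commit title describes, assuming TransformerEngine is installed and using a hypothetical standalone helper name, could look like:

import torch
from torch import nn

def replace_linear_with_te( module ):
	# Hypothetical helper; the repo's actual swap lives in ml.replace_linear.
	# Recursively swap torch.nn.Linear layers for TransformerEngine's fp8-capable Linear.
	import transformer_engine.pytorch as te

	for name, child in module.named_children():
		if isinstance( child, nn.Linear ):
			replacement = te.Linear( child.in_features, child.out_features, bias=child.bias is not None )
			# carry over the existing weights so behavior matches outside fp8 autocast regions
			with torch.no_grad():
				replacement.weight.copy_( child.weight )
				if child.bias is not None:
					replacement.bias.copy_( child.bias )
			setattr( module, name, replacement )
		else:
			replace_linear_with_te( child )
	return module

Gating the swap on cfg.fp8.enabled, as the hunk above does, keeps it opt-in: TE's Linear only pays off when fp8 autocasting is actually used.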
@@ -372,6 +372,7 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)
 
 	# copy embeddings if requested
+	"""
 	if cfg.models._embeddings is not None:
 		embeddings_path = cfg.relpath / cfg.models._embeddings
 
@@ -394,9 +395,10 @@ def example_usage():
 			continue
 		param.requires_grad_(False)
 		engine._frozen_params.add(param)
+	"""
 
-	# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
 		model.model = ml.replace_linear( model.model )
 
 	torch.save( {
 		'module': model.state_dict()
@@ -70,7 +70,7 @@ def autocasts(input, from_dtype, to_dtype):
 # handles temporarily upcasting 'index tensors' so torch will stop bitching
 def autocast_forward( func ):
 	def wrapper( self, input, *args, **kwargs ):
-		with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
+		with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
 			return func( self, k, *args, **kwargs )
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)
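The `autocasts` helper itself is defined just above this hunk and isn't shown here; a rough sketch of the idea it implements, assuming it only needs to yield a temporarily upcast view of the input, is:

from contextlib import contextmanager
import torch

@contextmanager
def autocasts( input, from_dtype, to_dtype ):
	# Rough sketch: if the tensor's dtype is one of the listed dtypes, hand the
	# caller a copy cast to to_dtype (int32 here); otherwise pass it through untouched.
	if input.dtype in from_dtype:
		yield input.to( to_dtype )
	else:
		yield input

With float16/bfloat16 added to the list, an Embedding that receives half-precision "indices" (for example after an over-eager cast elsewhere) presumably gets them coerced back to int32 instead of raising.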