diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 9622abb..35e163d 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -44,7 +44,7 @@ def load_engines(training=True): if inferencing: model._cfg.training = False - if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace: + if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled): model.model = ml.replace_linear( model.model ) if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer): diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 06dc3e4..3e4568b 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -372,6 +372,7 @@ def example_usage(): engine = Engine(model=model, optimizer=optimizer) # copy embeddings if requested + """ if cfg.models._embeddings is not None: embeddings_path = cfg.relpath / cfg.models._embeddings @@ -394,9 +395,10 @@ def example_usage(): continue param.requires_grad_(False) engine._frozen_params.add(param) + """ -# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace: - model.model = ml.replace_linear( model.model ) + if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled): + model.model = ml.replace_linear( model.model ) torch.save( { 'module': model.state_dict() diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index 3f92f29..e22c037 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -70,7 +70,7 @@ def autocasts(input, from_dtype, to_dtype): # handles temporarily upcasting 'index tensors' so torch will stop bitching def autocast_forward( func ): def wrapper( self, input, *args, **kwargs ): - with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k: + with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k: return func( self, k, *args, **kwargs ) return wrapper Embedding.forward = autocast_forward(Embedding.forward)