From db49ad43ab25c7f7ddd28ee730d4fd27043e938f Mon Sep 17 00:00:00 2001 From: shadeMe Date: Thu, 1 Jun 2023 17:43:49 +0200 Subject: [PATCH] Add `device` parameter to `Embedding` --- bitsandbytes/nn/modules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 92afea3..d8c19ff 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding): scale_grad_by_freq: bool = False, sparse: bool = False, _weight: Optional[Tensor] = None, + device: Optional[device] = None, ) -> None: super().__init__( num_embeddings, @@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding): scale_grad_by_freq, sparse, _weight, + device=device ) GlobalOptimManager.get_instance().register_module_override( self, "weight", {"optim_bits": 32}