diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 92afea3..d8c19ff 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding): scale_grad_by_freq: bool = False, sparse: bool = False, _weight: Optional[Tensor] = None, + device: Optional[device] = None, ) -> None: super().__init__( num_embeddings, @@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding): scale_grad_by_freq, sparse, _weight, + device=device ) GlobalOptimManager.get_instance().register_module_override( self, "weight", {"optim_bits": 32}