diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 4f82cdc..148ee80 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -38,6 +38,8 @@ class StableEmbedding(torch.nn.Embedding):
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
+        device=None,
+        dtype=None,
     ) -> None:
         super(StableEmbedding, self).__init__(
             num_embeddings,
@@ -48,8 +50,10 @@ class StableEmbedding(torch.nn.Embedding):
             scale_grad_by_freq,
             sparse,
             _weight,
+            device,
+            dtype,
         )
-        self.norm = torch.nn.LayerNorm(embedding_dim)
+        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
         GlobalOptimManager.get_instance().register_module_override(
             self, "weight", {"optim_bits": 32}
         )
@@ -81,7 +85,10 @@ class StableEmbedding(torch.nn.Embedding):
             self.sparse,
         )
 
-        return self.norm(emb)
+        # always apply layer norm in full precision
+        emb = emb.to(torch.get_default_dtype())
+
+        return self.norm(emb).to(self.weight.dtype)
 
 
 class Embedding(torch.nn.Embedding):
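
For context, a minimal usage sketch of the behavior after this change (not part of the diff; the sizes, the token tensor, and the "cuda" device are illustrative assumptions). StableEmbedding now forwards device/dtype to torch.nn.Embedding, keeps its LayerNorm in the default (full) precision, and casts the normalized output back to the embedding weight's dtype:

import torch
from bitsandbytes.nn import StableEmbedding

# Hypothetical sizes; assumes a CUDA-enabled bitsandbytes install.
emb = StableEmbedding(1000, 64, device="cuda", dtype=torch.float16)
tokens = torch.randint(0, 1000, (2, 8), device="cuda")

out = emb(tokens)
# The lookup itself happens in float16, but forward() upcasts to
# torch.get_default_dtype() (float32 by default) for the LayerNorm,
# then casts the result back to the weight dtype:
assert out.dtype == torch.float16

Note the asymmetry in the diff: dtype is passed through to the parent Embedding, while the LayerNorm receives only device, so its parameters stay in the default dtype and the norm math is always done in full precision regardless of the embedding dtype.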