forked from mrq/bitsandbytes-rocm
Merge pull request #87 from lostmsu/main
Add `device` and `dtype` parameters to `StableEmbedding`
This commit is contained in:
commit
9d353ca786
|
@ -25,6 +25,8 @@ class StableEmbedding(torch.nn.Embedding):
|
||||||
scale_grad_by_freq: bool = False,
|
scale_grad_by_freq: bool = False,
|
||||||
sparse: bool = False,
|
sparse: bool = False,
|
||||||
_weight: Optional[Tensor] = None,
|
_weight: Optional[Tensor] = None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
|
@ -35,8 +37,10 @@ class StableEmbedding(torch.nn.Embedding):
|
||||||
scale_grad_by_freq,
|
scale_grad_by_freq,
|
||||||
sparse,
|
sparse,
|
||||||
_weight,
|
_weight,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
)
|
)
|
||||||
self.norm = torch.nn.LayerNorm(embedding_dim)
|
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
|
||||||
GlobalOptimManager.get_instance().register_module_override(
|
GlobalOptimManager.get_instance().register_module_override(
|
||||||
self, "weight", {"optim_bits": 32}
|
self, "weight", {"optim_bits": 32}
|
||||||
)
|
)
|
||||||
|
@ -68,7 +72,10 @@ class StableEmbedding(torch.nn.Embedding):
|
||||||
self.sparse,
|
self.sparse,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.norm(emb)
|
# always apply layer norm in full precision
|
||||||
|
emb = emb.to(torch.get_default_dtype())
|
||||||
|
|
||||||
|
return self.norm(emb).to(self.weight.dtype)
|
||||||
|
|
||||||
|
|
||||||
class Embedding(torch.nn.Embedding):
|
class Embedding(torch.nn.Embedding):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user