Merge pull request #87 from lostmsu/main

Add `device` and `dtype` parameters to `StableEmbedding`
This commit is contained in:
Tim Dettmers 2023-01-02 13:22:45 +01:00 committed by GitHub
commit 9d353ca786
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -25,6 +25,8 @@ class StableEmbedding(torch.nn.Embedding):
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
device=None,
dtype=None,
) -> None:
super().__init__(
num_embeddings,
@@ -35,8 +37,10 @@ class StableEmbedding(torch.nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
device,
dtype,
)
self.norm = torch.nn.LayerNorm(embedding_dim)
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
)
@@ -68,7 +72,10 @@ class StableEmbedding(torch.nn.Embedding):
self.sparse,
)
return self.norm(emb)
# always apply layer norm in full precision
emb = emb.to(torch.get_default_dtype())
return self.norm(emb).to(self.weight.dtype)
class Embedding(torch.nn.Embedding):