add device and dtype parameters to StableEmbedding
This commit is contained in:
parent
1efb87d89d
commit
62d39a237c
|
@ -38,6 +38,8 @@ class StableEmbedding(torch.nn.Embedding):
|
|||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
_weight: Optional[Tensor] = None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super(StableEmbedding, self).__init__(
|
||||
num_embeddings,
|
||||
|
@ -48,8 +50,10 @@ class StableEmbedding(torch.nn.Embedding):
|
|||
scale_grad_by_freq,
|
||||
sparse,
|
||||
_weight,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
self.norm = torch.nn.LayerNorm(embedding_dim)
|
||||
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
|
||||
GlobalOptimManager.get_instance().register_module_override(
|
||||
self, "weight", {"optim_bits": 32}
|
||||
)
|
||||
|
@ -81,7 +85,10 @@ class StableEmbedding(torch.nn.Embedding):
|
|||
self.sparse,
|
||||
)
|
||||
|
||||
return self.norm(emb)
|
||||
# always apply layer norm in full precision
|
||||
emb = emb.to(torch.get_default_dtype())
|
||||
|
||||
return self.norm(emb).to(self.weight.dtype)
|
||||
|
||||
|
||||
class Embedding(torch.nn.Embedding):
|
||||
|
|
Loading…
Reference in New Issue
Block a user