Add device parameter to Embedding

This commit is contained in:
shadeMe 2023-06-01 17:43:49 +02:00
parent 9cac5dd1b6
commit db49ad43ab
No known key found for this signature in database
GPG Key ID: 6FCA9FC635B2A402

View File

@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq: bool = False, scale_grad_by_freq: bool = False,
sparse: bool = False, sparse: bool = False,
_weight: Optional[Tensor] = None, _weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
num_embeddings, num_embeddings,
@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq, scale_grad_by_freq,
sparse, sparse,
_weight, _weight,
device=device
) )
GlobalOptimManager.get_instance().register_module_override( GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32} self, "weight", {"optim_bits": 32}