Add device parameter to Embedding

This commit is contained in:
shadeMe 2023-06-01 17:43:49 +02:00
parent 9cac5dd1b6
commit db49ad43ab
No known key found for this signature in database
GPG Key ID: 6FCA9FC635B2A402

View File

@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None:
super().__init__(
num_embeddings,
@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
device=device
)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}