Add device
parameter to Embedding
This commit is contained in:
parent
9cac5dd1b6
commit
db49ad43ab
|
@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
|
|||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
_weight: Optional[Tensor] = None,
|
||||
device: Optional[device] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_embeddings,
|
||||
|
@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
|
|||
scale_grad_by_freq,
|
||||
sparse,
|
||||
_weight,
|
||||
device=device
|
||||
)
|
||||
GlobalOptimManager.get_instance().register_module_override(
|
||||
self, "weight", {"optim_bits": 32}
|
||||
|
|
Loading…
Reference in New Issue
Block a user