Add device
parameter to Embedding
This commit is contained in:
parent
9cac5dd1b6
commit
db49ad43ab
|
@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
|
||||||
scale_grad_by_freq: bool = False,
|
scale_grad_by_freq: bool = False,
|
||||||
sparse: bool = False,
|
sparse: bool = False,
|
||||||
_weight: Optional[Tensor] = None,
|
_weight: Optional[Tensor] = None,
|
||||||
|
device: Optional[device] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
|
@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
|
||||||
scale_grad_by_freq,
|
scale_grad_by_freq,
|
||||||
sparse,
|
sparse,
|
||||||
_weight,
|
_weight,
|
||||||
|
device=device
|
||||||
)
|
)
|
||||||
GlobalOptimManager.get_instance().register_module_override(
|
GlobalOptimManager.get_instance().register_module_override(
|
||||||
self, "weight", {"optim_bits": 32}
|
self, "weight", {"optim_bits": 32}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user