oops
This commit is contained in:
parent
75b04686f8
commit
e33c4b0cb1
|
@ -173,8 +173,8 @@ def tree_map(fn: Callable, x):
|
|||
return x
|
||||
|
||||
|
||||
def to_device(x: T | None, **kwargs) -> T:
|
||||
def to_device(x: T | None, *args, **kwargs) -> T:
|
||||
if x is None:
|
||||
return
|
||||
|
||||
return tree_map(lambda t: t.to(**kwargs), x)
|
||||
return tree_map(lambda t: t.to(*args, **kwargs), x)
|
||||
|
|
Loading…
Reference in New Issue
Block a user