This commit is contained in:
mrq 2024-07-22 19:38:39 -05:00
parent 75b04686f8
commit e33c4b0cb1

View File

@ -173,8 +173,8 @@ def tree_map(fn: Callable, x):
return x
def to_device(x: T | None, **kwargs) -> T:
def to_device(x: T | None, *args, **kwargs) -> T:
if x is None:
return
return tree_map(lambda t: t.to(**kwargs), x)
return tree_map(lambda t: t.to(*args, **kwargs), x)