diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index 99d49eb..37d8ac4 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -173,8 +173,8 @@ def tree_map(fn: Callable, x): return x -def to_device(x: T | None, **kwargs) -> T: +def to_device(x: T | None, *args, **kwargs) -> T: if x is None: return - return tree_map(lambda t: t.to(**kwargs), x) + return tree_map(lambda t: t.to(*args, **kwargs), x)