Clone and detach in recursively_detach

This commit is contained in:
James Betker 2020-10-07 12:41:00 -06:00
parent 1c44d395af
commit a62a5dbb5f

View File

@ -356,7 +356,7 @@ class ProgressBar(object):
# Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure. # Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure.
def recursively_detach(v): def recursively_detach(v):
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
return v.detach() return v.detach().clone()
elif isinstance(v, list) or isinstance(v, tuple): elif isinstance(v, list) or isinstance(v, tuple):
out = [recursively_detach(i) for i in v] out = [recursively_detach(i) for i in v]
if isinstance(v, tuple): if isinstance(v, tuple):