Clone and detach in recursively_detach

This commit is contained in:
James Betker 2020-10-07 12:41:00 -06:00
parent 1c44d395af
commit a62a5dbb5f

View File

@ -356,7 +356,7 @@ class ProgressBar(object):
# Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure.
def recursively_detach(v):
if isinstance(v, torch.Tensor):
return v.detach()
return v.detach().clone()
elif isinstance(v, list) or isinstance(v, tuple):
out = [recursively_detach(i) for i in v]
if isinstance(v, tuple):