diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index 6027b38f..7f389dba 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -1180,7 +1180,9 @@ class TransformerWrapper(nn.Module): if use_cache: res.append(intermediates.past_key_values) - return res + if len(res) > 1: + return tuple(res) + return res[0] class ContinuousTransformerWrapper(nn.Module): @@ -1241,7 +1243,9 @@ class ContinuousTransformerWrapper(nn.Module): if use_cache: res.append(intermediates.past_key_values) - return tuple(res) + if len(res) > 1: + return tuple(res) + return res[0] class XTransformer(nn.Module):