From 423293e51851e1398265bfe89254103f7a2ee786 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 8 Apr 2022 09:12:46 -0600 Subject: [PATCH] fix xtransformers bug --- codes/models/lucidrains/x_transformers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index 6027b38f..7f389dba 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -1180,7 +1180,9 @@ class TransformerWrapper(nn.Module): if use_cache: res.append(intermediates.past_key_values) - return res + if len(res) > 1: + return tuple(res) + return res[0] class ContinuousTransformerWrapper(nn.Module): @@ -1241,7 +1243,9 @@ class ContinuousTransformerWrapper(nn.Module): if use_cache: res.append(intermediates.past_key_values) - return tuple(res) + if len(res) > 1: + return tuple(res) + return res[0] class XTransformer(nn.Module):