fix xtransformers bug

This commit is contained in:
James Betker 2022-04-08 09:12:46 -06:00
parent 048f6f729a
commit 423293e518

View File

@ -1180,7 +1180,9 @@ class TransformerWrapper(nn.Module):
if use_cache:
res.append(intermediates.past_key_values)
return res
if len(res) > 1:
return tuple(res)
return res[0]
class ContinuousTransformerWrapper(nn.Module):
@ -1241,7 +1243,9 @@ class ContinuousTransformerWrapper(nn.Module):
if use_cache:
res.append(intermediates.past_key_values)
return tuple(res)
if len(res) > 1:
return tuple(res)
return res[0]
class XTransformer(nn.Module):