fix xtransformers bug
This commit is contained in:
parent
048f6f729a
commit
423293e518
|
@ -1180,7 +1180,9 @@ class TransformerWrapper(nn.Module):
|
|||
if use_cache:
|
||||
res.append(intermediates.past_key_values)
|
||||
|
||||
return res
|
||||
if len(res) > 1:
|
||||
return tuple(res)
|
||||
return res[0]
|
||||
|
||||
|
||||
class ContinuousTransformerWrapper(nn.Module):
|
||||
|
@ -1241,7 +1243,9 @@ class ContinuousTransformerWrapper(nn.Module):
|
|||
if use_cache:
|
||||
res.append(intermediates.past_key_values)
|
||||
|
||||
return tuple(res)
|
||||
if len(res) > 1:
|
||||
return tuple(res)
|
||||
return res[0]
|
||||
|
||||
|
||||
class XTransformer(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue
Block a user