forked from mrq/DL-Art-School
fix xtransformers bug
This commit is contained in:
parent
048f6f729a
commit
423293e518
|
@ -1180,7 +1180,9 @@ class TransformerWrapper(nn.Module):
|
||||||
if use_cache:
|
if use_cache:
|
||||||
res.append(intermediates.past_key_values)
|
res.append(intermediates.past_key_values)
|
||||||
|
|
||||||
return res
|
if len(res) > 1:
|
||||||
|
return tuple(res)
|
||||||
|
return res[0]
|
||||||
|
|
||||||
|
|
||||||
class ContinuousTransformerWrapper(nn.Module):
|
class ContinuousTransformerWrapper(nn.Module):
|
||||||
|
@ -1241,7 +1243,9 @@ class ContinuousTransformerWrapper(nn.Module):
|
||||||
if use_cache:
|
if use_cache:
|
||||||
res.append(intermediates.past_key_values)
|
res.append(intermediates.past_key_values)
|
||||||
|
|
||||||
return tuple(res)
|
if len(res) > 1:
|
||||||
|
return tuple(res)
|
||||||
|
return res[0]
|
||||||
|
|
||||||
|
|
||||||
class XTransformer(nn.Module):
|
class XTransformer(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user