oops
This commit is contained in:
parent
69c1d2991f
commit
e5f9da2221
|
@ -690,9 +690,6 @@ class MixtralModel_Adapted(MixtralModel):
|
|||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
|
@ -750,7 +747,7 @@ class MixtralModel_Adapted(MixtralModel):
|
|||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
x_mask,
|
||||
is_causal,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
|
@ -763,7 +760,7 @@ class MixtralModel_Adapted(MixtralModel):
|
|||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=x_mask,
|
||||
is_causal=is_causal,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
|
|
Loading…
Reference in New Issue
Block a user