commit e5f9da2221 (parent 69c1d2991f)
Author: mrq
Date: 2025-01-21 11:59:24 -06:00


@@ -690,9 +690,6 @@ class MixtralModel_Adapted(MixtralModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
                 use_cache = False

         if use_cache and past_key_values is None:
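The removed lines silence the stock warning but keep the behavior: under gradient checkpointing each layer's forward is re-run during the backward pass, so any KV cache produced on the first pass would be stale, and `use_cache` must still be forced off. A minimal sketch of that standard pattern (not the repo's code; `decoder_layer` and argument order are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Sketch of the pattern the hunk above edits: when checkpointing is active
# during training, disable the KV cache and route the layer call through
# torch's checkpoint utility so activations are recomputed on backward.
def run_layer(decoder_layer, hidden_states, attention_mask,
              gradient_checkpointing, training, use_cache):
    if gradient_checkpointing and training:
        use_cache = False  # cached key/values would be recomputed anyway
        return checkpoint(decoder_layer, hidden_states, attention_mask,
                          use_reentrant=False)
    return decoder_layer(hidden_states, attention_mask)
```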
@@ -750,7 +747,7 @@ class MixtralModel_Adapted(MixtralModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    x_mask,
                     is_causal,
                     position_ids,
                     past_key_values,
@@ -763,7 +760,7 @@ class MixtralModel_Adapted(MixtralModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=x_mask,
                     is_causal=is_causal,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
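Both branches of the layer loop now forward `x_mask` where they previously forwarded `causal_mask`; together with the separate `is_causal` flag already present in both calls, this suggests the mask itself no longer encodes causality, only which positions may be attended to (e.g. padding). A hedged sketch of what such a mask might look like (all names below are illustrative assumptions, not from this commit):

```python
import torch

# Hypothetical helper: expand a [batch, seq_len] padding mask (1 = real
# token, 0 = padding) into the additive [batch, 1, 1, seq_len] form that
# attention layers consume. Causality is left to the is_causal flag.
def build_x_mask(padding_mask: torch.Tensor,
                 dtype: torch.dtype = torch.float32) -> torch.Tensor:
    inverted = 1.0 - padding_mask[:, None, None, :].to(dtype)  # 0 keep, 1 drop
    return inverted * torch.finfo(dtype).min  # masked positions get -inf

# Usage: mask = build_x_mask(torch.tensor([[1, 1, 1, 0]]))
```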