From e5f9da22216afe6eb3fbc531eba22a695f0b6b45 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 21 Jan 2025 11:59:24 -0600 Subject: [PATCH] oops --- vall_e/models/arch/mixtral.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vall_e/models/arch/mixtral.py b/vall_e/models/arch/mixtral.py index 6462dab..9bb736e 100644 --- a/vall_e/models/arch/mixtral.py +++ b/vall_e/models/arch/mixtral.py @@ -690,9 +690,6 @@ class MixtralModel_Adapted(MixtralModel): if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) use_cache = False if use_cache and past_key_values is None: @@ -750,7 +747,7 @@ class MixtralModel_Adapted(MixtralModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + x_mask, is_causal, position_ids, past_key_values, @@ -763,7 +760,7 @@ class MixtralModel_Adapted(MixtralModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=x_mask, is_causal=is_causal, position_ids=position_ids, past_key_value=past_key_values,