i swear it worked before and now it didnt

This commit is contained in:
mrq 2024-06-25 19:17:14 -05:00
parent 6ee5f21ddc
commit 20789a0b8a

View File

@ -431,7 +431,7 @@ class NewGenerationMixin(GenerationMixin):
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, input_ids.device)
logits_warper = self._get_logits_warper(generation_config) #, input_ids.device)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@ -458,7 +458,7 @@ class NewGenerationMixin(GenerationMixin):
)
elif is_sample_gen_stream_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, input_ids.device)
logits_warper = self._get_logits_warper(generation_config) #, input_ids.device)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@ -524,7 +524,7 @@ class NewGenerationMixin(GenerationMixin):
elif is_beam_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, input_ids.device)
logits_warper = self._get_logits_warper(generation_config) #, input_ids.device)
if stopping_criteria.max_length is None:
raise ValueError(