diff --git a/tortoise_tts/models/stream_generator.py b/tortoise_tts/models/stream_generator.py index 5f41407..4cad80a 100644 --- a/tortoise_tts/models/stream_generator.py +++ b/tortoise_tts/models/stream_generator.py @@ -431,7 +431,7 @@ class NewGenerationMixin(GenerationMixin): elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config, input_ids.device) + logits_warper = self._get_logits_warper(generation_config) #, input_ids.device) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -458,7 +458,7 @@ class NewGenerationMixin(GenerationMixin): ) elif is_sample_gen_stream_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config, input_ids.device) + logits_warper = self._get_logits_warper(generation_config) #, input_ids.device) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -524,7 +524,7 @@ class NewGenerationMixin(GenerationMixin): elif is_beam_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config, input_ids.device) + logits_warper = self._get_logits_warper(generation_config) #, input_ids.device) if stopping_criteria.max_length is None: raise ValueError(