From 18adfaf78529cbf66bca37cdf44f0bb4e0cf9159 Mon Sep 17 00:00:00 2001 From: ken11o2 Date: Mon, 4 Sep 2023 19:12:13 +0000 Subject: [PATCH] add use_deepspeed to contructor and update method post_init_gpt2_config --- tortoise/api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index c8691d8..88acb40 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -259,7 +259,8 @@ class TextToSpeech: unsqueeze_sample_batches=False, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None, - ): +# ): + use_deepspeed=False): # Add use_deepspeed parameter """ Constructor :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing @@ -280,7 +281,8 @@ class TextToSpeech: self.output_sample_rate = output_sample_rate self.minor_optimizations = minor_optimizations self.unsqueeze_sample_batches = unsqueeze_sample_batches - + self.use_deepspeed = use_deepspeed # Store use_deepspeed as an instance variable + print(f'use_deepspeed api_debug {use_deepspeed}') # for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations self.preloaded_tensors = minor_optimizations self.use_kv_cache = minor_optimizations @@ -359,7 +361,7 @@ class TextToSpeech: heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cpu().eval() self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path)) - self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache) + self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache) if self.preloaded_tensors: self.autoregressive = migrate_to_device( self.autoregressive, self.device )