diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 1207a5f..06f53b2 100755
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -352,7 +352,7 @@ class UnifiedVoice(nn.Module):
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
 
-    def post_init_gpt2_config(self, kv_cache=False):
+    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
                                 n_positions=seq_length,
@@ -363,6 +363,17 @@ class UnifiedVoice(nn.Module):
                                 gradient_checkpointing=False,
                                 use_cache=True)
         self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
+        #print(f'use_deepspeed autoregressive_debug {use_deepspeed}')
+        if use_deepspeed and torch.cuda.is_available():
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(model=self.inference_model,  
+                                                    mp_size=1,
+                                                    replace_with_kernel_inject=True,
+                                                    dtype=torch.float32)
+            self.inference_model = self.ds_engine.module.eval()
+        else:
+            self.inference_model = self.inference_model.eval()
+			
         self.gpt.wte = self.mel_embedding
 
     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):