diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 1207a5f..06f53b2 100755 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -352,7 +352,7 @@ class UnifiedVoice(nn.Module): for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) - def post_init_gpt2_config(self, kv_cache=False): + def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False): seq_length = self.max_mel_tokens + self.max_text_tokens + 2 gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, n_positions=seq_length, @@ -363,6 +363,17 @@ class UnifiedVoice(nn.Module): gradient_checkpointing=False, use_cache=True) self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache) + #print(f'use_deepspeed autoregressive_debug {use_deepspeed}') + if use_deepspeed and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=True, + dtype=torch.float32) + self.inference_model = self.ds_engine.module.eval() + else: + self.inference_model = self.inference_model.eval() + self.gpt.wte = self.mel_embedding def build_aligned_inputs_and_targets(self, input, start_token, stop_token):