1
1
forked from mrq/tortoise-tts

add use_deepspeed

This commit is contained in:
ken11o2 2023-09-04 19:10:27 +00:00
parent b10c58436d
commit ac97c17bf7

View File

@ -352,7 +352,7 @@ class UnifiedVoice(nn.Module):
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, kv_cache=False):
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length,
@ -363,6 +363,17 @@ class UnifiedVoice(nn.Module):
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
#print(f'use_deepspeed autoregressive_debug {use_deepspeed}')
if use_deepspeed and torch.cuda.is_available():
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float32)
self.inference_model = self.ds_engine.module.eval()
else:
self.inference_model = self.inference_model.eval()
self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):