forked from mrq/tortoise-tts
add use_deepspeed
This commit is contained in:
parent
b10c58436d
commit
ac97c17bf7
|
@ -352,7 +352,7 @@ class UnifiedVoice(nn.Module):
|
||||||
for module in embeddings:
|
for module in embeddings:
|
||||||
module.weight.data.normal_(mean=0.0, std=.02)
|
module.weight.data.normal_(mean=0.0, std=.02)
|
||||||
|
|
||||||
def post_init_gpt2_config(self, kv_cache=False):
|
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
|
||||||
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
||||||
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
|
@ -363,6 +363,17 @@ class UnifiedVoice(nn.Module):
|
||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
use_cache=True)
|
use_cache=True)
|
||||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
|
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
|
||||||
|
#print(f'use_deepspeed autoregressive_debug {use_deepspeed}')
|
||||||
|
if use_deepspeed and torch.cuda.is_available():
|
||||||
|
import deepspeed
|
||||||
|
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
|
||||||
|
mp_size=1,
|
||||||
|
replace_with_kernel_inject=True,
|
||||||
|
dtype=torch.float32)
|
||||||
|
self.inference_model = self.ds_engine.module.eval()
|
||||||
|
else:
|
||||||
|
self.inference_model = self.inference_model.eval()
|
||||||
|
|
||||||
self.gpt.wte = self.mel_embedding
|
self.gpt.wte = self.mel_embedding
|
||||||
|
|
||||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user