forked from mrq/tortoise-tts
add use_deepspeed to constructor and update method post_init_gpt2_config
This commit is contained in:
parent ac97c17bf7
commit 18adfaf785
@@ -259,7 +259,8 @@ class TextToSpeech:
         unsqueeze_sample_batches=False,
         input_sample_rate=22050, output_sample_rate=24000,
         autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
-        ):
+        # ):
+        use_deepspeed=False): # Add use_deepspeed parameter
         """
             Constructor
             :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
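With this change, DeepSpeed can be toggled from the public constructor, and the flag defaults to False so existing callers are unaffected. A minimal usage sketch, assuming the standard tortoise.api import path and leaving every other constructor argument at its default:

    from tortoise.api import TextToSpeech

    # Opt in at construction time; the flag is stored on the instance and
    # forwarded to post_init_gpt2_config when the autoregressive model loads.
    tts = TextToSpeech(use_deepspeed=True)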
@@ -280,7 +281,8 @@ class TextToSpeech:
         self.output_sample_rate = output_sample_rate
         self.minor_optimizations = minor_optimizations
         self.unsqueeze_sample_batches = unsqueeze_sample_batches
+        self.use_deepspeed = use_deepspeed # Store use_deepspeed as an instance variable
+        print(f'use_deepspeed api_debug {use_deepspeed}')
         # for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
         self.preloaded_tensors = minor_optimizations
         self.use_kv_cache = minor_optimizations
@@ -359,7 +361,7 @@ class TextToSpeech:
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False).cpu().eval()
         self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
-        self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
+        self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache)
         if self.preloaded_tensors:
             self.autoregressive = migrate_to_device( self.autoregressive, self.device )
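For context, a plausible shape for the receiving side of the flag. This is a hedged sketch of post_init_gpt2_config, not code from this commit: deepspeed.init_inference is DeepSpeed's real inference entry point, but the self.inference_model attribute and the surrounding wiring are assumptions made for illustration.

    import torch

    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
        # ... existing GPT-2 inference-model setup elided ...
        if use_deepspeed:
            import deepspeed
            # Wrap the inference module in DeepSpeed's inference engine.
            # mp_size=1 disables model parallelism; kernel injection swaps
            # in DeepSpeed's fused transformer kernels where supported.
            self.ds_engine = deepspeed.init_inference(
                model=self.inference_model,  # assumed attribute name
                mp_size=1,
                replace_with_kernel_inject=True,
                dtype=torch.float32,
            )
            self.inference_model = self.ds_engine.module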