forked from mrq/tortoise-tts
Merge pull request 'main' (#47) from ken11o2/tortoise-tts:main into main
Reviewed-on: mrq/tortoise-tts#47
This commit is contained in:
commit
d7e6914fb8
|
@ -259,7 +259,8 @@ class TextToSpeech:
|
||||||
unsqueeze_sample_batches=False,
|
unsqueeze_sample_batches=False,
|
||||||
input_sample_rate=22050, output_sample_rate=24000,
|
input_sample_rate=22050, output_sample_rate=24000,
|
||||||
autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
|
autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
|
||||||
):
|
# ):
|
||||||
|
use_deepspeed=False): # Add use_deepspeed parameter
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -280,7 +281,8 @@ class TextToSpeech:
|
||||||
self.output_sample_rate = output_sample_rate
|
self.output_sample_rate = output_sample_rate
|
||||||
self.minor_optimizations = minor_optimizations
|
self.minor_optimizations = minor_optimizations
|
||||||
self.unsqueeze_sample_batches = unsqueeze_sample_batches
|
self.unsqueeze_sample_batches = unsqueeze_sample_batches
|
||||||
|
self.use_deepspeed = use_deepspeed # Store use_deepspeed as an instance variable
|
||||||
|
print(f'use_deepspeed api_debug {use_deepspeed}')
|
||||||
# for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
|
# for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
|
||||||
self.preloaded_tensors = minor_optimizations
|
self.preloaded_tensors = minor_optimizations
|
||||||
self.use_kv_cache = minor_optimizations
|
self.use_kv_cache = minor_optimizations
|
||||||
|
@ -359,7 +361,7 @@ class TextToSpeech:
|
||||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||||
train_solo_embeddings=False).cpu().eval()
|
train_solo_embeddings=False).cpu().eval()
|
||||||
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
||||||
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
|
self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache)
|
||||||
if self.preloaded_tensors:
|
if self.preloaded_tensors:
|
||||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
||||||
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
|
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
|
||||||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
||||||
|
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=True)
|
||||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||||
|
@ -37,8 +38,8 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
#print(f'use_deepspeed do_tts_debug {use_deepspeed}')
|
||||||
tts = TextToSpeech(models_dir=args.model_dir)
|
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed)
|
||||||
|
|
||||||
selected_voices = args.voice.split(',')
|
selected_voices = args.voice.split(',')
|
||||||
for k, selected_voice in enumerate(selected_voices):
|
for k, selected_voice in enumerate(selected_voices):
|
||||||
|
|
|
@ -352,7 +352,7 @@ class UnifiedVoice(nn.Module):
|
||||||
for module in embeddings:
|
for module in embeddings:
|
||||||
module.weight.data.normal_(mean=0.0, std=.02)
|
module.weight.data.normal_(mean=0.0, std=.02)
|
||||||
|
|
||||||
def post_init_gpt2_config(self, kv_cache=False):
|
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
|
||||||
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
||||||
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
|
@ -363,6 +363,17 @@ class UnifiedVoice(nn.Module):
|
||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
use_cache=True)
|
use_cache=True)
|
||||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
|
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
|
||||||
|
#print(f'use_deepspeed autoregressive_debug {use_deepspeed}')
|
||||||
|
if use_deepspeed and torch.cuda.is_available():
|
||||||
|
import deepspeed
|
||||||
|
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
|
||||||
|
mp_size=1,
|
||||||
|
replace_with_kernel_inject=True,
|
||||||
|
dtype=torch.float32)
|
||||||
|
self.inference_model = self.ds_engine.module.eval()
|
||||||
|
else:
|
||||||
|
self.inference_model = self.inference_model.eval()
|
||||||
|
|
||||||
self.gpt.wte = self.mel_embedding
|
self.gpt.wte = self.mel_embedding
|
||||||
|
|
||||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||||
|
|
|
@ -17,6 +17,7 @@ if __name__ == '__main__':
|
||||||
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
|
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
|
||||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
||||||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
||||||
|
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=True)
|
||||||
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
||||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
|
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
|
@ -25,7 +26,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tts = TextToSpeech(models_dir=args.model_dir)
|
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed)
|
||||||
|
|
||||||
outpath = args.output_path
|
outpath = args.output_path
|
||||||
selected_voices = args.voice.split(',')
|
selected_voices = args.voice.split(',')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user