forked from mrq/tortoise-tts
Compare commits
7 Commits
Author | SHA1 | Date | |
---|---|---|---|
95f679f4ba | |||
bf3b6c87aa | |||
|
d7e6914fb8 | ||
|
b7c7fd1c5f | ||
|
2478dc255e | ||
|
18adfaf785 | ||
|
ac97c17bf7 |
|
@ -259,7 +259,8 @@ class TextToSpeech:
|
||||||
unsqueeze_sample_batches=False,
|
unsqueeze_sample_batches=False,
|
||||||
input_sample_rate=22050, output_sample_rate=24000,
|
input_sample_rate=22050, output_sample_rate=24000,
|
||||||
autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
|
autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
|
||||||
):
|
# ):
|
||||||
|
use_deepspeed=False): # Add use_deepspeed parameter
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -280,7 +281,8 @@ class TextToSpeech:
|
||||||
self.output_sample_rate = output_sample_rate
|
self.output_sample_rate = output_sample_rate
|
||||||
self.minor_optimizations = minor_optimizations
|
self.minor_optimizations = minor_optimizations
|
||||||
self.unsqueeze_sample_batches = unsqueeze_sample_batches
|
self.unsqueeze_sample_batches = unsqueeze_sample_batches
|
||||||
|
self.use_deepspeed = use_deepspeed # Store use_deepspeed as an instance variable
|
||||||
|
print(f'use_deepspeed api_debug {use_deepspeed}')
|
||||||
# for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
|
# for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
|
||||||
self.preloaded_tensors = minor_optimizations
|
self.preloaded_tensors = minor_optimizations
|
||||||
self.use_kv_cache = minor_optimizations
|
self.use_kv_cache = minor_optimizations
|
||||||
|
@ -336,7 +338,7 @@ class TextToSpeech:
|
||||||
|
|
||||||
self.loading = False
|
self.loading = False
|
||||||
|
|
||||||
def load_autoregressive_model(self, autoregressive_model_path):
|
def load_autoregressive_model(self, autoregressive_model_path, is_xtts=False):
|
||||||
if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path):
|
if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -354,12 +356,40 @@ class TextToSpeech:
|
||||||
if hasattr(self, 'autoregressive'):
|
if hasattr(self, 'autoregressive'):
|
||||||
del self.autoregressive
|
del self.autoregressive
|
||||||
|
|
||||||
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
# XTTS requires a different "dimensionality" for its autoregressive model
|
||||||
model_dim=1024,
|
if new_hash == "e4ce21eae0043f7691d6a6c8540b74b8" or is_xtts:
|
||||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
dimensionality = {
|
||||||
train_solo_embeddings=False).cpu().eval()
|
"max_mel_tokens": 605,
|
||||||
|
"max_text_tokens": 402,
|
||||||
|
"max_prompt_tokens": 70,
|
||||||
|
"max_conditioning_inputs": 1,
|
||||||
|
"layers": 30,
|
||||||
|
"model_dim": 1024,
|
||||||
|
"heads": 16,
|
||||||
|
"number_text_tokens": 5023, # -1
|
||||||
|
"start_text_token": 261,
|
||||||
|
"stop_text_token": 0,
|
||||||
|
"number_mel_codes": 8194,
|
||||||
|
"start_mel_token": 8192,
|
||||||
|
"stop_mel_token": 8193,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
dimensionality = {
|
||||||
|
"max_mel_tokens": 604,
|
||||||
|
"max_text_tokens": 402,
|
||||||
|
"max_conditioning_inputs": 2,
|
||||||
|
"layers": 30,
|
||||||
|
"model_dim": 1024,
|
||||||
|
"heads": 16,
|
||||||
|
"number_text_tokens": 255,
|
||||||
|
"start_text_token": 255,
|
||||||
|
"checkpointing": False,
|
||||||
|
"train_solo_embeddings": False
|
||||||
|
}
|
||||||
|
|
||||||
|
self.autoregressive = UnifiedVoice(**dimensionality).cpu().eval()
|
||||||
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
||||||
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
|
self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache)
|
||||||
if self.preloaded_tensors:
|
if self.preloaded_tensors:
|
||||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
||||||
|
|
||||||
|
@ -378,9 +408,21 @@ class TextToSpeech:
|
||||||
if hasattr(self, 'diffusion'):
|
if hasattr(self, 'diffusion'):
|
||||||
del self.diffusion
|
del self.diffusion
|
||||||
|
|
||||||
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
# XTTS does not require a different "dimensionality" for its diffusion model
|
||||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
dimensionality = {
|
||||||
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
"model_channels": 1024,
|
||||||
|
"num_layers": 10,
|
||||||
|
"in_channels": 100,
|
||||||
|
"out_channels": 200,
|
||||||
|
"in_latent_channels": 1024,
|
||||||
|
"in_tokens": 8193,
|
||||||
|
"dropout": 0,
|
||||||
|
"use_fp16": False,
|
||||||
|
"num_heads": 16,
|
||||||
|
"layer_drop": 0,
|
||||||
|
"unconditioned_percentage": 0
|
||||||
|
}
|
||||||
|
self.diffusion = DiffusionTts(**dimensionality)
|
||||||
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir)))
|
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir)))
|
||||||
if self.preloaded_tensors:
|
if self.preloaded_tensors:
|
||||||
self.diffusion = migrate_to_device( self.diffusion, self.device )
|
self.diffusion = migrate_to_device( self.diffusion, self.device )
|
||||||
|
@ -773,7 +815,10 @@ class TextToSpeech:
|
||||||
|
|
||||||
clip_results = torch.cat(clip_results, dim=0)
|
clip_results = torch.cat(clip_results, dim=0)
|
||||||
samples = torch.cat(samples, dim=0)
|
samples = torch.cat(samples, dim=0)
|
||||||
best_results = samples[torch.topk(clip_results, k=k).indices]
|
if k < num_autoregressive_samples:
|
||||||
|
best_results = samples[torch.topk(clip_results, k=k).indices]
|
||||||
|
else:
|
||||||
|
best_results = samples
|
||||||
|
|
||||||
if not self.preloaded_tensors:
|
if not self.preloaded_tensors:
|
||||||
self.clvp = migrate_to_device( self.clvp, 'cpu' )
|
self.clvp = migrate_to_device( self.clvp, 'cpu' )
|
||||||
|
|
|
@ -14,6 +14,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
||||||
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
|
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
|
||||||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
||||||
|
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=True)
|
||||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||||
|
@ -37,8 +38,8 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
#print(f'use_deepspeed do_tts_debug {use_deepspeed}')
|
||||||
tts = TextToSpeech(models_dir=args.model_dir)
|
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed)
|
||||||
|
|
||||||
selected_voices = args.voice.split(',')
|
selected_voices = args.voice.split(',')
|
||||||
for k, selected_voice in enumerate(selected_voices):
|
for k, selected_voice in enumerate(selected_voices):
|
||||||
|
|
|
@ -283,9 +283,9 @@ class MelEncoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class UnifiedVoice(nn.Module):
|
class UnifiedVoice(nn.Module):
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_prompt_tokens=2, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||||
mel_length_compression=1024, number_text_tokens=256,
|
mel_length_compression=1024, number_text_tokens=256,
|
||||||
start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=None, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||||
checkpointing=True, types=1):
|
checkpointing=True, types=1):
|
||||||
"""
|
"""
|
||||||
|
@ -295,6 +295,7 @@ class UnifiedVoice(nn.Module):
|
||||||
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
|
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
|
||||||
max_text_tokens: Maximum number of text tokens that will be encountered by model.
|
max_text_tokens: Maximum number of text tokens that will be encountered by model.
|
||||||
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
|
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
|
||||||
|
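max_prompt_tokens: Number of extra positions added to the GPT sequence length on top of max_mel_tokens and max_text_tokens. The default of 2 matches the previously hard-coded value (kept for compatibility); XTTS models use 70.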
|
||||||
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
|
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
|
||||||
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
|
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
|
||||||
number_text_tokens:
|
number_text_tokens:
|
||||||
|
@ -311,7 +312,7 @@ class UnifiedVoice(nn.Module):
|
||||||
|
|
||||||
self.number_text_tokens = number_text_tokens
|
self.number_text_tokens = number_text_tokens
|
||||||
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
||||||
self.stop_text_token = 0
|
self.stop_text_token = stop_text_token
|
||||||
self.number_mel_codes = number_mel_codes
|
self.number_mel_codes = number_mel_codes
|
||||||
self.start_mel_token = start_mel_token
|
self.start_mel_token = start_mel_token
|
||||||
self.stop_mel_token = stop_mel_token
|
self.stop_mel_token = stop_mel_token
|
||||||
|
@ -319,6 +320,7 @@ class UnifiedVoice(nn.Module):
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_text_tokens = max_text_tokens
|
self.max_text_tokens = max_text_tokens
|
||||||
|
self.max_prompt_tokens = max_prompt_tokens
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
|
@ -352,8 +354,8 @@ class UnifiedVoice(nn.Module):
|
||||||
for module in embeddings:
|
for module in embeddings:
|
||||||
module.weight.data.normal_(mean=0.0, std=.02)
|
module.weight.data.normal_(mean=0.0, std=.02)
|
||||||
|
|
||||||
def post_init_gpt2_config(self, kv_cache=False):
|
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
|
||||||
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
|
||||||
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
n_ctx=seq_length,
|
n_ctx=seq_length,
|
||||||
|
@ -363,6 +365,17 @@ class UnifiedVoice(nn.Module):
|
||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
use_cache=True)
|
use_cache=True)
|
||||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
|
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
|
||||||
|
#print(f'use_deepspeed autoregressive_debug {use_deepspeed}')
|
||||||
|
if use_deepspeed and torch.cuda.is_available():
|
||||||
|
import deepspeed
|
||||||
|
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
|
||||||
|
mp_size=1,
|
||||||
|
replace_with_kernel_inject=True,
|
||||||
|
dtype=torch.float32)
|
||||||
|
self.inference_model = self.ds_engine.module.eval()
|
||||||
|
else:
|
||||||
|
self.inference_model = self.inference_model.eval()
|
||||||
|
|
||||||
self.gpt.wte = self.mel_embedding
|
self.gpt.wte = self.mel_embedding
|
||||||
|
|
||||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||||
|
@ -483,7 +496,7 @@ class UnifiedVoice(nn.Module):
|
||||||
|
|
||||||
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
|
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
|
||||||
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
||||||
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
self.post_init_gpt2_config(kv_cache=self.kv_cache)
|
self.post_init_gpt2_config(kv_cache=self.kv_cache)
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ if __name__ == '__main__':
|
||||||
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
|
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
|
||||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
||||||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
||||||
|
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=True)
|
||||||
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
||||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
|
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
|
@ -25,7 +26,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tts = TextToSpeech(models_dir=args.model_dir)
|
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed)
|
||||||
|
|
||||||
outpath = args.output_path
|
outpath = args.output_path
|
||||||
selected_voices = args.voice.split(',')
|
selected_voices = args.voice.split(',')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user