Update docs

This commit is contained in:
James Betker 2022-02-03 22:18:21 -07:00
parent 86004ad967
commit 16f5d4f625
3 changed files with 110 additions and 84 deletions

View File

@ -14,19 +14,25 @@ expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardwar
## What the heck is this? ## What the heck is this?
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together: Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together.
These models are all derived from different repositories which are all linked. All the models have been modified
for this use case (some substantially so).
First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
similar to the GPT model used by DALLE, except it operates on speech data. similar to the GPT model used by DALLE, except it operates on speech data.
Based on: [GPT2 from Transformers](https://huggingface.co/docs/transformers/model_doc/gpt2)
Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
decoding creates considerably better results. decoding creates considerably better results.
Based on [CLIP from lucidrains](https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py)
Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE. Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
Based on [VQVAE2 by rosinality](https://github.com/rosinality/vq-vae-2-pytorch)
Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
a wav file. a wav file.
Based on [ImprovedDiffusion by openai](https://github.com/openai/improved-diffusion)
## How do I use this? ## How do I use this?

184
do_tts.py
View File

@ -25,25 +25,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps)) model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128): def load_conditioning(path, sample_rate=22050, cond_length=132300):
"""
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
"""
with torch.no_grad():
mel = dvae_model.decode(mel_codes)[0]
# Pad MEL to multiples of 2048//spectrogram_compression_factor
msl = mel.shape[-1]
dsl = 2048 // spectrogram_compression_factor
gap = dsl - (msl % dsl)
if gap > 0:
mel = torch.nn.functional.pad(mel, (0, gap))
output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
def load_conditioning(path, sample_rate=22050, cond_length=44100):
rel_clip = load_audio(path, sample_rate) rel_clip = load_audio(path, sample_rate)
gap = rel_clip.shape[-1] - cond_length gap = rel_clip.shape[-1] - cond_length
if gap < 0: if gap < 0:
@ -82,86 +64,122 @@ def fix_autoregressive_output(codes, stop_token):
return codes return codes
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, mean=False):
"""
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
"""
with torch.no_grad():
mel = dvae_model.decode(mel_codes)[0]
# Pad MEL to multiples of 2048//spectrogram_compression_factor
msl = mel.shape[-1]
dsl = 2048 // spectrogram_compression_factor
gap = dsl - (msl % dsl)
if gap > 0:
mel = torch.nn.functional.pad(mel, (0, gap))
output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
if mean:
return diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
else:
return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
if __name__ == '__main__': if __name__ == '__main__':
# These are voices drawn randomly from the training set. You are free to substitute your own voices in, but testing
# has shown that the model does not generalize to new voices very well.
preselected_cond_voices = { preselected_cond_voices = {
'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'], # Male voices
'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'], 'dotrice': ['voices/dotrice/1.wav', 'voices/dotrice/2.wav'],
'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav', 'Y:\\clips\\books1\\15_dchha16 Nazi Tidbits\\00036.wav'], 'harris': ['voices/male_harris1.wav', 'voices/male_harris2.wav'],
'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'], 'lescault': ['voices/male_lescault1.wav', 'voices/male_lescault2.wav'],
'otto': ['voices/male_otto1.wav', 'voices/male_otto2.wav'],
# Female voices
'atkins': ['voices/female_atkins1.wav', 'voices/female_atkins2.wav'],
'grace': ['voices/female_grace1.wav', 'voices/female_grace2.wav'],
'kennard': ['voices/female_kennard1.wav', 'voices/female_kennard2.wav'],
'mol': ['voices/female_mol1.wav', 'voices/female_mol2.wav'],
} }
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth') parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth')
parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth') parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth')
parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='./models/diffusion_vocoder.pth') parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='.models/diffusion_vocoder.pth')
parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='./models/dvae.pth') parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='.models/dvae.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dan_carlin') parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32) parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=2) parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16)
parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2) parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/') parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
args = parser.parse_args() args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True) os.makedirs(args.output_path, exist_ok=True)
print("Loading GPT TTS..") for voice in args.voice.split(','):
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024, heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).eval() print("Loading GPT TTS..")
autoregressive.load_state_dict(torch.load(args.autoregressive_model_path)) autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
stop_mel_token = autoregressive.stop_mel_token heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
stop_mel_token = autoregressive.stop_mel_token
print("Loading data..") print("Loading data..")
tokenizer = VoiceBpeTokenizer() tokenizer = VoiceBpeTokenizer()
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda() text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
text = F.pad(text, (0,1)) # This may not be necessary. text = F.pad(text, (0,1)) # This may not be necessary.
cond_paths = preselected_cond_voices[args.cond_preset] cond_paths = preselected_cond_voices[voice]
conds = [] conds = []
for cond_path in cond_paths: for cond_path in cond_paths:
c, cond_wav = load_conditioning(cond_path, cond_length=132300) c, cond_wav = load_conditioning(cond_path)
conds.append(c) conds.append(c)
conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model. conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
with torch.no_grad(): with torch.no_grad():
print("Performing GPT inference..") print("Performing autoregressive inference..")
samples = [] samples = []
for b in tqdm(range(args.num_batches)): for b in tqdm(range(args.num_batches)):
codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95, codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1) temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
padding_needed = 250 - codes.shape[1] padding_needed = 250 - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes) samples.append(codes)
samples = torch.cat(samples, dim=0) del autoregressive
del autoregressive
print("Loading CLIP..") print("Loading CLIP..")
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8, clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).eval() num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()
clip.load_state_dict(torch.load(args.clip_model_path)) clip.load_state_dict(torch.load(args.clip_model_path))
print("Performing CLIP filtering..") print("Performing CLIP filtering..")
for i in range(samples.shape[0]): clip_results = []
samples[i] = fix_autoregressive_output(samples[i], stop_mel_token) for batch in samples:
clip_results = clip(text.repeat(samples.shape[0], 1), for i in range(batch.shape[0]):
torch.full((samples.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'), batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
samples, torch.full((samples.shape[0],), fill_value=samples.shape[1]*1024, dtype=torch.long, device='cuda'), text = text[:, :120] # Ugly hack to fix the fact that I didn't train CLIP to handle long enough text.
return_loss=False) clip_results.append(clip(text.repeat(batch.shape[0], 1),
best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices] torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'),
return_loss=False))
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
# Delete the autoregressive and clip models to free up GPU memory # Delete the autoregressive and clip models to free up GPU memory
del samples, clip del samples, clip
print("Loading DVAE..") print("Loading DVAE..")
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2, dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
record_codes=True, kernel_size=3, use_transposed_convs=False).eval() record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
dvae.load_state_dict(torch.load(args.dvae_model_path)) dvae.load_state_dict(torch.load(args.dvae_model_path))
print("Loading Diffusion Model..") print("Loading Diffusion Model..")
diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1], diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2, spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, time_embed_dim_multiplier=4).eval() conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
diffusion.load_state_dict(torch.load(args.diffusion_model_path)) diffusion.load_state_dict(torch.load(args.diffusion_model_path))
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
print("Performing vocoding..") print("Performing vocoding..")
# Perform vocoding on each batch element separately: Vocoding is very memory (and compute!) intensive. # Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
for b in range(best_results.shape[0]): for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0) code = best_results[b].unsqueeze(0)
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256) wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)
torchaudio.save(os.path.join(args.output_path, f'gpt_tts_output_{b}.wav'), wav.squeeze(0).cpu(), 22050) torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 22050)

View File

@ -1,5 +1,7 @@
import torch import torch
import torchaudio import torchaudio
import numpy as np
from scipy.io.wavfile import read
def load_wav_to_torch(full_path): def load_wav_to_torch(full_path):