Update docs

This commit is contained in:
James Betker 2022-02-03 22:18:21 -07:00
parent 86004ad967
commit 16f5d4f625
3 changed files with 110 additions and 84 deletions

View File

@ -14,19 +14,25 @@ expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardwar
## What the heck is this?
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together:
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together.
These models are all derived from different repositories which are all linked. All the models have been modified
for this use case (some substantially so).
First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
similar to the GPT model used by DALLE, except it operates on speech data.
Based on: [GPT2 from Transformers](https://huggingface.co/docs/transformers/model_doc/gpt2)
Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
decoding creates considerably better results.
Based on [CLIP from lucidrains](https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py)
Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
Based on [VQVAE2 by rosinality](https://github.com/rosinality/vq-vae-2-pytorch)
Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
a wav file.
Based on [ImprovedDiffusion by openai](https://github.com/openai/improved-diffusion)
## How do I use this?

184
do_tts.py
View File

@ -25,25 +25,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128):
"""
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
"""
with torch.no_grad():
mel = dvae_model.decode(mel_codes)[0]
# Pad MEL to multiples of 2048//spectrogram_compression_factor
msl = mel.shape[-1]
dsl = 2048 // spectrogram_compression_factor
gap = dsl - (msl % dsl)
if gap > 0:
mel = torch.nn.functional.pad(mel, (0, gap))
output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
def load_conditioning(path, sample_rate=22050, cond_length=44100):
def load_conditioning(path, sample_rate=22050, cond_length=132300):
rel_clip = load_audio(path, sample_rate)
gap = rel_clip.shape[-1] - cond_length
if gap < 0:
@ -82,86 +64,122 @@ def fix_autoregressive_output(codes, stop_token):
return codes
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, mean=False):
"""
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
"""
with torch.no_grad():
mel = dvae_model.decode(mel_codes)[0]
# Pad MEL to multiples of 2048//spectrogram_compression_factor
msl = mel.shape[-1]
dsl = 2048 // spectrogram_compression_factor
gap = dsl - (msl % dsl)
if gap > 0:
mel = torch.nn.functional.pad(mel, (0, gap))
output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
if mean:
return diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
else:
return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
if __name__ == '__main__':
# These are voices drawn randomly from the training set. You are free to substitute your own voices in, but testing
# has shown that the model does not generalize to new voices very well.
preselected_cond_voices = {
'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'],
'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'],
'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav', 'Y:\\clips\\books1\\15_dchha16 Nazi Tidbits\\00036.wav'],
'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'],
# Male voices
'dotrice': ['voices/dotrice/1.wav', 'voices/dotrice/2.wav'],
'harris': ['voices/male_harris1.wav', 'voices/male_harris2.wav'],
'lescault': ['voices/male_lescault1.wav', 'voices/male_lescault2.wav'],
'otto': ['voices/male_otto1.wav', 'voices/male_otto2.wav'],
# Female voices
'atkins': ['voices/female_atkins1.wav', 'voices/female_atkins2.wav'],
'grace': ['voices/female_grace1.wav', 'voices/female_grace2.wav'],
'kennard': ['voices/female_kennard1.wav', 'voices/female_kennard2.wav'],
'mol': ['voices/female_mol1.wav', 'voices/female_mol2.wav'],
}
parser = argparse.ArgumentParser()
parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth')
parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth')
parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='./models/diffusion_vocoder.pth')
parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='./models/dvae.pth')
parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='.models/diffusion_vocoder.pth')
parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='.models/dvae.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dan_carlin')
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=2)
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16)
parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
print("Loading GPT TTS..")
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024, heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).eval()
autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
stop_mel_token = autoregressive.stop_mel_token
for voice in args.voice.split(','):
print("Loading GPT TTS..")
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
stop_mel_token = autoregressive.stop_mel_token
print("Loading data..")
tokenizer = VoiceBpeTokenizer()
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
text = F.pad(text, (0,1)) # This may not be necessary.
cond_paths = preselected_cond_voices[args.cond_preset]
conds = []
for cond_path in cond_paths:
c, cond_wav = load_conditioning(cond_path, cond_length=132300)
conds.append(c)
conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
print("Loading data..")
tokenizer = VoiceBpeTokenizer()
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
text = F.pad(text, (0,1)) # This may not be necessary.
cond_paths = preselected_cond_voices[voice]
conds = []
for cond_path in cond_paths:
c, cond_wav = load_conditioning(cond_path)
conds.append(c)
conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
with torch.no_grad():
print("Performing GPT inference..")
samples = []
for b in tqdm(range(args.num_batches)):
codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
padding_needed = 250 - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
samples = torch.cat(samples, dim=0)
del autoregressive
with torch.no_grad():
print("Performing autoregressive inference..")
samples = []
for b in tqdm(range(args.num_batches)):
codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
padding_needed = 250 - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
del autoregressive
print("Loading CLIP..")
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).eval()
clip.load_state_dict(torch.load(args.clip_model_path))
print("Performing CLIP filtering..")
for i in range(samples.shape[0]):
samples[i] = fix_autoregressive_output(samples[i], stop_mel_token)
clip_results = clip(text.repeat(samples.shape[0], 1),
torch.full((samples.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
samples, torch.full((samples.shape[0],), fill_value=samples.shape[1]*1024, dtype=torch.long, device='cuda'),
return_loss=False)
best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
print("Loading CLIP..")
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()
clip.load_state_dict(torch.load(args.clip_model_path))
print("Performing CLIP filtering..")
clip_results = []
for batch in samples:
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
text = text[:, :120] # Ugly hack to fix the fact that I didn't train CLIP to handle long enough text.
clip_results.append(clip(text.repeat(batch.shape[0], 1),
torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'),
return_loss=False))
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
# Delete the autoregressive and clip models to free up GPU memory
del samples, clip
# Delete the autoregressive and clip models to free up GPU memory
del samples, clip
print("Loading DVAE..")
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
record_codes=True, kernel_size=3, use_transposed_convs=False).eval()
dvae.load_state_dict(torch.load(args.dvae_model_path))
print("Loading Diffusion Model..")
diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, time_embed_dim_multiplier=4).eval()
diffusion.load_state_dict(torch.load(args.diffusion_model_path))
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
print("Loading DVAE..")
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
dvae.load_state_dict(torch.load(args.dvae_model_path))
print("Loading Diffusion Model..")
diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
diffusion.load_state_dict(torch.load(args.diffusion_model_path))
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
print("Performing vocoding..")
# Perform vocoding on each batch element separately: Vocoding is very memory (and compute!) intensive.
for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0)
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256)
torchaudio.save(os.path.join(args.output_path, f'gpt_tts_output_{b}.wav'), wav.squeeze(0).cpu(), 22050)
print("Performing vocoding..")
# Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0)
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)
torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 22050)

View File

@ -1,5 +1,7 @@
import torch
import torchaudio
import numpy as np
from scipy.io.wavfile import read
def load_wav_to_torch(full_path):