1
1
forked from mrq/tortoise-tts

Support totally random voices (and make fixes to previous changes)

This commit is contained in:
James Betker 2022-05-02 15:40:03 -06:00
parent a57fcaf814
commit ee24d3ee4b
8 changed files with 125 additions and 34 deletions

3
.gitignore vendored
View File

@ -129,6 +129,7 @@ dmypy.json
.pyre/ .pyre/
.idea/* .idea/*
.models/* tortoise/.models/*
tortoise/random_voices/*
.custom/* .custom/*
results/* results/*

View File

@ -1,5 +1,6 @@
import os import os
import random import random
import uuid
from urllib import request from urllib import request
import torch import torch
@ -15,6 +16,7 @@ from tqdm import tqdm
from tortoise.models.arch_util import TorchMelSpectrogram from tortoise.models.arch_util import TorchMelSpectrogram
from tortoise.models.clvp import CLVP from tortoise.models.clvp import CLVP
from tortoise.models.random_latent_generator import RandomLatentConverter
from tortoise.models.vocoder import UnivNetGenerator from tortoise.models.vocoder import UnivNetGenerator
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
@ -161,7 +163,8 @@ class TextToSpeech:
Main entry point into Tortoise. Main entry point into Tortoise.
""" """
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True): def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True,
save_random_voices=False):
""" """
Constructor Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -170,11 +173,15 @@ class TextToSpeech:
models, otherwise use the defaults. models, otherwise use the defaults.
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
(but are still rendered by the model). This can be used for prompt engineering. (but are still rendered by the model). This can be used for prompt engineering.
Default is true.
:param save_random_voices: When true, voices that are randomly generated are saved to the `random_voices`
directory. Default is false.
""" """
self.autoregressive_batch_size = autoregressive_batch_size self.autoregressive_batch_size = autoregressive_batch_size
self.enable_redaction = enable_redaction self.enable_redaction = enable_redaction
if self.enable_redaction: if self.enable_redaction:
self.aligner = Wav2VecAlignment() self.aligner = Wav2VecAlignment()
self.save_random_voices = save_random_voices
self.tokenizer = VoiceBpeTokenizer() self.tokenizer = VoiceBpeTokenizer()
download_models() download_models()
@ -210,6 +217,10 @@ class TextToSpeech:
self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g']) self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
self.vocoder.eval(inference=True) self.vocoder.eval(inference=True)
# Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None
self.rlg_diffusion = None
def tts_with_preset(self, text, preset='fast', **kwargs): def tts_with_preset(self, text, preset='fast', **kwargs):
""" """
Calls TTS with one of a set of preset generation parameters. Options: Calls TTS with one of a set of preset generation parameters. Options:
@ -265,7 +276,21 @@ class TextToSpeech:
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu() self.diffusion = self.diffusion.cpu()
return auto_latent, diffusion_latent return auto_latent, diffusion_latent, auto_conds
def get_random_conditioning_latents(self):
# Lazy-load the RLG models.
if self.rlg_auto is None:
self.rlg_auto = RandomLatentConverter(1024).eval()
self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu')))
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
with torch.no_grad():
latents = self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
if self.save_random_voices:
os.makedirs('random_voices', exist_ok=True)
torch.save(latents, f'random_voices/{str(uuid.uuid4())}.pth')
return latents
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
# autoregressive generation parameters follow # autoregressive generation parameters follow
@ -323,14 +348,19 @@ class TextToSpeech:
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz. Sample rate is 24kHz.
""" """
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
text = F.pad(text, (0, 1)) # This may not be necessary. text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert text.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
auto_conds = None
if voice_samples is not None: if voice_samples is not None:
auto_conditioning, diffusion_conditioning = self.get_conditioning_latents(voice_samples) auto_conditioning, diffusion_conditioning, auto_conds = self.get_conditioning_latents(voice_samples)
else: elif conditioning_latents is not None:
auto_conditioning, diffusion_conditioning = conditioning_latents auto_conditioning, diffusion_conditioning = conditioning_latents
else:
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
auto_conditioning = auto_conditioning.cuda()
diffusion_conditioning = diffusion_conditioning.cuda()
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
@ -343,7 +373,7 @@ class TextToSpeech:
if verbose: if verbose:
print("Generating autoregressive samples..") print("Generating autoregressive samples..")
for b in tqdm(range(num_batches), disable=not verbose): for b in tqdm(range(num_batches), disable=not verbose):
codes = self.autoregressive.inference_speech(auto_conditioning, text, codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True, do_sample=True,
top_p=top_p, top_p=top_p,
temperature=temperature, temperature=temperature,
@ -365,12 +395,15 @@ class TextToSpeech:
for batch in tqdm(samples, disable=not verbose): for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]): for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False) clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
cvvp_accumulator = 0 if auto_conds is not None:
for cl in range(conds.shape[1]): cvvp_accumulator = 0
cvvp_accumulator = cvvp_accumulator + self.cvvp(conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False ) for cl in range(auto_conds.shape[1]):
cvvp = cvvp_accumulator / conds.shape[1] cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider)) cvvp = cvvp_accumulator / auto_conds.shape[1]
clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider))
else:
clip_results.append(clvp)
clip_results = torch.cat(clip_results, dim=0) clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0) samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices] best_results = samples[torch.topk(clip_results, k=k).indices]
@ -382,8 +415,8 @@ class TextToSpeech:
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage. # results, but will increase memory usage.
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.cuda()
best_latents = self.autoregressive(auto_conditioning, text, torch.tensor([text.shape[-1]], device=conds.device), best_results, best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device), torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False) return_latent=True, clip_inputs=False)
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()
del auto_conditioning del auto_conditioning
@ -415,7 +448,7 @@ class TextToSpeech:
self.diffusion = self.diffusion.cpu() self.diffusion = self.diffusion.cpu()
self.vocoder = self.vocoder.cpu() self.vocoder = self.vocoder.cpu()
def potentially_redact(self, clip, text): def potentially_redact(clip, text):
if self.enable_redaction: if self.enable_redaction:
return self.aligner.redact(clip, text) return self.aligner.redact(clip, text)
return clip return clip

View File

@ -10,23 +10,23 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat') 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
parser.add_argument('--voice_diversity_intelligibility_slider', type=float, parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility', help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
default=.5) default=.5)
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/') parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/')
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
'should only be specified if you have custom checkpoints.', default='.models') 'should only be specified if you have custom checkpoints.', default='.models')
args = parser.parse_args() args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True) os.makedirs(args.output_path, exist_ok=True)
tts = TextToSpeech(models_dir=args.model_dir) tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True)
selected_voices = args.voice.split(',') selected_voices = args.voice.split(',')
for voice in selected_voices: for k, voice in enumerate(selected_voices):
voice_samples, conditioning_latents = load_voice(voice) voice_samples, conditioning_latents = load_voice(voice)
gen = tts.tts_with_preset(args.text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, gen = tts.tts_with_preset(args.text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider) preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), gen.squeeze(0).cpu(), 24000) torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)

View File

@ -401,13 +401,13 @@ class UnifiedVoice(nn.Module):
conds = conds.mean(dim=1).unsqueeze(1) conds = conds.mean(dim=1).unsqueeze(1)
return conds return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False, clip_inputs=True): return_latent=False, clip_inputs=True):
""" """
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`). (actuated by `text_first`).
speech_conditioning_input: MEL float tensor, (b,80,s) speech_conditioning_input: MEL float tensor, (b,1024)
text_inputs: long tensor, (b,t) text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,) text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m) mel_inputs: long tensor, (b,m)
@ -421,7 +421,7 @@ class UnifiedVoice(nn.Module):
# Types are expressed by expanding the text embedding space. # Types are expressed by expanding the text embedding space.
if types is not None: if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1) text_inputs = text_inputs * (1+types).unsqueeze(-1)
if clip_inputs: if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length. # chopping the inputs by the maximum actual length.
@ -435,7 +435,7 @@ class UnifiedVoice(nn.Module):
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
conds = self.get_conditioning(speech_conditioning_input) conds = speech_conditioning_latent.unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
@ -540,7 +540,7 @@ class UnifiedVoice(nn.Module):
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
conds = speech_conditioning_latent conds = speech_conditioning_latent.unsqueeze(1)
emb = torch.cat([conds, text_emb], dim=1) emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_mel_emb(emb) self.inference_model.store_mel_emb(emb)

View File

@ -226,6 +226,7 @@ class DiffusionTts(nn.Module):
for j in range(speech_conditioning_input.shape[1]): for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1) conds = torch.cat(conds, dim=-1)
conds = conds.mean(dim=-1)
return conds return conds
def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred): def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
@ -233,9 +234,7 @@ class DiffusionTts(nn.Module):
if is_latent(aligned_conditioning): if is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1) aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
conds = conditioning_latent cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
cond_emb = conds.mean(dim=-1)
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
if is_latent(aligned_conditioning): if is_latent(aligned_conditioning):
code_emb = self.latent_conditioner(aligned_conditioning) code_emb = self.latent_conditioner(aligned_conditioning)
else: else:

View File

@ -0,0 +1,55 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (
F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
)
* scale
)
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
class EqualLinear(nn.Module):
def __init__(
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
return out
class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
nn.Linear(channels, channels))
self.channels = channels
def forward(self, ref):
r = torch.randn(ref.shape[0], self.channels, device=ref.device)
y = self.layers(r)
return y
if __name__ == '__main__':
model = RandomLatentConverter(512)
model(torch.randn(5,512))

View File

@ -31,7 +31,7 @@ if __name__ == '__main__':
parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat') 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/longform/')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
parser.add_argument('--voice_diversity_intelligibility_slider', type=float, parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
@ -40,7 +40,7 @@ if __name__ == '__main__':
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
'should only be specified if you have custom checkpoints.', default='.models') 'should only be specified if you have custom checkpoints.', default='.models')
args = parser.parse_args() args = parser.parse_args()
tts = TextToSpeech(models_dir=args.model_dir) tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True)
outpath = args.output_path outpath = args.output_path
selected_voices = args.voice.split(',') selected_voices = args.voice.split(',')

View File

@ -92,6 +92,9 @@ def get_voices():
def load_voice(voice): def load_voice(voice):
if voice == 'random':
return None, None
voices = get_voices() voices = get_voices()
paths = voices[voice] paths = voices[voice]
if len(paths) == 1 and paths[0].endswith('.pth'): if len(paths) == 1 and paths[0].endswith('.pth'):