diff --git a/.gitignore b/.gitignore index 82504f8..d632a14 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,7 @@ dmypy.json .pyre/ .idea/* -.models/* +tortoise/.models/* +tortoise/random_voices/* .custom/* results/* \ No newline at end of file diff --git a/tortoise/api.py b/tortoise/api.py index 6cb4733..07ce6b4 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -1,5 +1,6 @@ import os import random +import uuid from urllib import request import torch @@ -15,6 +16,7 @@ from tqdm import tqdm from tortoise.models.arch_util import TorchMelSpectrogram from tortoise.models.clvp import CLVP +from tortoise.models.random_latent_generator import RandomLatentConverter from tortoise.models.vocoder import UnivNetGenerator from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule @@ -161,7 +163,8 @@ class TextToSpeech: Main entry point into Tortoise. """ - def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True): + def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True, + save_random_voices=False): """ Constructor :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing @@ -170,11 +173,15 @@ class TextToSpeech: models, otherwise use the defaults. :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output (but are still rendered by the model). This can be used for prompt engineering. + Default is true. + :param save_random_voices: When true, voices that are randomly generated are saved to the `random_voices` + directory. Default is false. """ self.autoregressive_batch_size = autoregressive_batch_size self.enable_redaction = enable_redaction if self.enable_redaction: self.aligner = Wav2VecAlignment() + self.save_random_voices = save_random_voices self.tokenizer = VoiceBpeTokenizer() download_models() @@ -210,6 +217,10 @@ class TextToSpeech: self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g']) self.vocoder.eval(inference=True) + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + self.rlg_diffusion = None + def tts_with_preset(self, text, preset='fast', **kwargs): """ Calls TTS with one of a set of preset generation parameters. Options: @@ -265,7 +276,21 @@ class TextToSpeech: diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) self.diffusion = self.diffusion.cpu() - return auto_latent, diffusion_latent + return auto_latent, diffusion_latent, auto_conds + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu'))) + self.rlg_diffusion = RandomLatentConverter(2048).eval() + self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu'))) + with torch.no_grad(): + latents = self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) + if self.save_random_voices: + os.makedirs('random_voices', exist_ok=True) + torch.save(latents, f'random_voices/{str(uuid.uuid4())}.pth') + return latents def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, # autoregressive generation parameters follow @@ -323,14 +348,19 @@ class TextToSpeech: :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. """ - text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() - text = F.pad(text, (0, 1)) # This may not be necessary. - assert text.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' + auto_conds = None if voice_samples is not None: - auto_conditioning, diffusion_conditioning = self.get_conditioning_latents(voice_samples) - else: + auto_conditioning, diffusion_conditioning, auto_conds = self.get_conditioning_latents(voice_samples) + elif conditioning_latents is not None: auto_conditioning, diffusion_conditioning = conditioning_latents + else: + auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.cuda() + diffusion_conditioning = diffusion_conditioning.cuda() diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) @@ -343,7 +373,7 @@ class TextToSpeech: if verbose: print("Generating autoregressive samples..") for b in tqdm(range(num_batches), disable=not verbose): - codes = self.autoregressive.inference_speech(auto_conditioning, text, + codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, do_sample=True, top_p=top_p, temperature=temperature, @@ -365,12 +395,15 @@ class TextToSpeech: for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) - clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False) - cvvp_accumulator = 0 - for cl in range(conds.shape[1]): - cvvp_accumulator = cvvp_accumulator + self.cvvp(conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False ) - cvvp = cvvp_accumulator / conds.shape[1] - clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider)) + clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) + if auto_conds is not None: + cvvp_accumulator = 0 + for cl in range(auto_conds.shape[1]): + cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) + cvvp = cvvp_accumulator / auto_conds.shape[1] + clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider)) + else: + clip_results.append(clvp) clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) best_results = samples[torch.topk(clip_results, k=k).indices] @@ -382,8 +415,8 @@ class TextToSpeech: # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # results, but will increase memory usage. self.autoregressive = self.autoregressive.cuda() - best_latents = self.autoregressive(auto_conditioning, text, torch.tensor([text.shape[-1]], device=conds.device), best_results, - torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device), + best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, + torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), return_latent=True, clip_inputs=False) self.autoregressive = self.autoregressive.cpu() del auto_conditioning @@ -415,7 +448,7 @@ class TextToSpeech: self.diffusion = self.diffusion.cpu() self.vocoder = self.vocoder.cpu() - def potentially_redact(self, clip, text): + def potentially_redact(clip, text): if self.enable_redaction: return self.aligner.redact(clip, text) return clip diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py index 6f2bd88..6abf8ea 100644 --- a/tortoise/do_tts.py +++ b/tortoise/do_tts.py @@ -10,23 +10,23 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' - 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat') - parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') + 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random') + parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast') parser.add_argument('--voice_diversity_intelligibility_slider', type=float, help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility', default=.5) - parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/') + parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/') parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' 'should only be specified if you have custom checkpoints.', default='.models') args = parser.parse_args() os.makedirs(args.output_path, exist_ok=True) - tts = TextToSpeech(models_dir=args.model_dir) + tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True) selected_voices = args.voice.split(',') - for voice in selected_voices: + for k, voice in enumerate(selected_voices): voice_samples, conditioning_latents = load_voice(voice) gen = tts.tts_with_preset(args.text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider) - torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), gen.squeeze(0).cpu(), 24000) + torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000) diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 1cd02cb..e56ad27 100644 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -401,13 +401,13 @@ class UnifiedVoice(nn.Module): conds = conds.mean(dim=1).unsqueeze(1) return conds - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, + def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, return_latent=False, clip_inputs=True): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode (actuated by `text_first`). - speech_conditioning_input: MEL float tensor, (b,80,s) + speech_conditioning_input: MEL float tensor, (b,1024) text_inputs: long tensor, (b,t) text_lengths: long tensor, (b,) mel_inputs: long tensor, (b,m) @@ -421,7 +421,7 @@ class UnifiedVoice(nn.Module): # Types are expressed by expanding the text embedding space. if types is not None: text_inputs = text_inputs * (1+types).unsqueeze(-1) - + if clip_inputs: # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # chopping the inputs by the maximum actual length. @@ -435,7 +435,7 @@ class UnifiedVoice(nn.Module): text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) - conds = self.get_conditioning(speech_conditioning_input) + conds = speech_conditioning_latent.unsqueeze(1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) @@ -540,7 +540,7 @@ class UnifiedVoice(nn.Module): text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - conds = speech_conditioning_latent + conds = speech_conditioning_latent.unsqueeze(1) emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_mel_emb(emb) diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py index fc3990f..f67d21a 100644 --- a/tortoise/models/diffusion_decoder.py +++ b/tortoise/models/diffusion_decoder.py @@ -226,6 +226,7 @@ class DiffusionTts(nn.Module): for j in range(speech_conditioning_input.shape[1]): conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) conds = torch.cat(conds, dim=-1) + conds = conds.mean(dim=-1) return conds def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred): @@ -233,9 +234,7 @@ class DiffusionTts(nn.Module): if is_latent(aligned_conditioning): aligned_conditioning = aligned_conditioning.permute(0, 2, 1) - conds = conditioning_latent - cond_emb = conds.mean(dim=-1) - cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) + cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1) if is_latent(aligned_conditioning): code_emb = self.latent_conditioner(aligned_conditioning) else: diff --git a/tortoise/models/random_latent_generator.py b/tortoise/models/random_latent_generator.py new file mode 100644 index 0000000..e90ef21 --- /dev/null +++ b/tortoise/models/random_latent_generator.py @@ -0,0 +1,55 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope + ) + * scale + ) + else: + return F.leaky_relu(input, negative_slope=0.2) * scale + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1 + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + return out + + +class RandomLatentConverter(nn.Module): + def __init__(self, channels): + super().__init__() + self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)], + nn.Linear(channels, channels)) + self.channels = channels + + def forward(self, ref): + r = torch.randn(ref.shape[0], self.channels, device=ref.device) + y = self.layers(r) + return y + + +if __name__ == '__main__': + model = RandomLatentConverter(512) + model(torch.randn(5,512)) \ No newline at end of file diff --git a/tortoise/read.py b/tortoise/read.py index b22f62e..2300308 100644 --- a/tortoise/read.py +++ b/tortoise/read.py @@ -31,7 +31,7 @@ if __name__ == '__main__': parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat') - parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') + parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/longform/') parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) parser.add_argument('--voice_diversity_intelligibility_slider', type=float, @@ -40,7 +40,7 @@ if __name__ == '__main__': parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' 'should only be specified if you have custom checkpoints.', default='.models') args = parser.parse_args() - tts = TextToSpeech(models_dir=args.model_dir) + tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True) outpath = args.output_path selected_voices = args.voice.split(',') diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py index c9cbcf5..a33abf1 100644 --- a/tortoise/utils/audio.py +++ b/tortoise/utils/audio.py @@ -92,6 +92,9 @@ def get_voices(): def load_voice(voice): + if voice == 'random': + return None, None + voices = get_voices() paths = voices[voice] if len(paths) == 1 and paths[0].endswith('.pth'):