From a52e3026ba844d7cb9c9c7f4f426da01a814023c Mon Sep 17 00:00:00 2001 From: Johan Nordberg Date: Wed, 25 May 2022 10:22:50 +0000 Subject: [PATCH 1/2] Revive CVVP model --- tortoise/api.py | 88 ++++++++++++++++++------- tortoise/do_tts.py | 4 +- tortoise/models/cvvp.py | 143 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 211 insertions(+), 24 deletions(-) create mode 100644 tortoise/models/cvvp.py diff --git a/tortoise/api.py b/tortoise/api.py index b9d8c82..59b1604 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -16,6 +16,7 @@ from tqdm import tqdm from tortoise.models.arch_util import TorchMelSpectrogram from tortoise.models.clvp import CLVP +from tortoise.models.cvvp import CVVP from tortoise.models.random_latent_generator import RandomLatentConverter from tortoise.models.vocoder import UnivNetGenerator from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel @@ -26,21 +27,23 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment pbar = None MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models') +MODELS = { + 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth', + 'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth', + 'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth', + 'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth', + 'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth', + 'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth', + 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', + 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', +} def download_models(specific_models=None): """ Call to download all the models that Tortoise uses. """ - MODELS = { - 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth', - 'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth', - 'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth', - 'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth', - 'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth', - 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', - 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', - } os.makedirs(MODELS_DIR, exist_ok=True) + def show_progress(block_num, block_size, total_size): global pbar if pbar is None: @@ -64,6 +67,18 @@ def download_models(specific_models=None): print('Done.') +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. + """ + if model_name not in MODELS: + raise ValueError(f'Model {model_name} not found in available models.') + model_path = os.path.join(models_dir, model_name) + if not os.path.exists(model_path) and models_dir == MODELS_DIR: + download_models([model_name]) + return model_path + + def pad_or_truncate(t, length): """ Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. @@ -151,11 +166,10 @@ def classify_audio_clip(clip): :param clip: torch tensor containing audio waveform data (get it from load_audio) :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. """ - download_models(['classifier.pth']) classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, dropout=0, kernel_size=5, distribute_zero_label=False) - classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu'))) + classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu'))) clip = clip.cpu().unsqueeze(0) results = F.softmax(classifier(clip), dim=-1) return results[0][0] @@ -193,13 +207,13 @@ class TextToSpeech: (but are still rendered by the model). This can be used for prompt engineering. Default is true. """ + self.models_dir = models_dir self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.enable_redaction = enable_redaction if self.enable_redaction: self.aligner = Wav2VecAlignment() self.tokenizer = VoiceBpeTokenizer() - download_models() if os.path.exists(f'{models_dir}/autoregressive.ptt'): # Assume this is a traced directory. @@ -210,27 +224,34 @@ class TextToSpeech: model_dim=1024, heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cpu().eval() - self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth')) + self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir))) self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, layer_drop=0, unconditioned_percentage=0).cpu().eval() - self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth')) + self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir))) self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20, text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430, use_xformers=True).cpu().eval() - self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp2.pth')) + self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir))) + self.cvvp = None # CVVP model is only loaded if used. self.vocoder = UnivNetGenerator().cpu() - self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g']) + self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g']) self.vocoder.eval(inference=True) # Random latent generators (RLGs) are loaded lazily. self.rlg_auto = None self.rlg_diffusion = None + def load_cvvp(self): + """Load CVVP model.""" + self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, + speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() + self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir))) + def get_conditioning_latents(self, voice_samples, return_mels=False): """ Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). @@ -273,9 +294,9 @@ class TextToSpeech: # Lazy-load the RLG models. if self.rlg_auto is None: self.rlg_auto = RandomLatentConverter(1024).eval() - self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu'))) + self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) self.rlg_diffusion = RandomLatentConverter(2048).eval() - self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu'))) + self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu'))) with torch.no_grad(): return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) @@ -305,6 +326,8 @@ class TextToSpeech: return_deterministic_state=False, # autoregressive generation parameters follow num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, # diffusion generation parameters follow diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, **hf_generate_kwargs): @@ -330,6 +353,9 @@ class TextToSpeech: I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but could use some tuning. :param typical_mass: The typical_mass parameter from the typical_sampling algorithm. + ~~CLVP-CVVP KNOBS~~ + :param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model. + [0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model. ~~DIFFUSION KNOBS~~ :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, @@ -391,19 +417,35 @@ class TextToSpeech: samples.append(codes) self.autoregressive = self.autoregressive.cpu() - clvp_results = [] + clip_results = [] self.clvp = self.clvp.cuda() + if cvvp_amount > 0: + if self.cvvp is None: + self.load_cvvp() + self.cvvp = self.cvvp.cuda() if verbose: - print("Computing best candidates using CLVP") + if self.cvvp is None: + print("Computing best candidates using CLVP") + else: + print(f"Computing best candidates using CLVP {int((1-cvvp_amount) * 100):02d}% and CVVP {int(cvvp_amount * 100):02d}%") for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) - clvp_results.append(clvp) - clvp_results = torch.cat(clvp_results, dim=0) + if auto_conds is not None and cvvp_amount > 0: + cvvp_accumulator = 0 + for cl in range(auto_conds.shape[1]): + cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) + cvvp = cvvp_accumulator / auto_conds.shape[1] + clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount)) + else: + clip_results.append(clvp) + clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) - best_results = samples[torch.topk(clvp_results, k=k).indices] + best_results = samples[torch.topk(clip_results, k=k).indices] self.clvp = self.clvp.cpu() + if self.cvvp is not None: + self.cvvp = self.cvvp.cpu() del samples # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py index 1adccc1..47f78ec 100644 --- a/tortoise/do_tts.py +++ b/tortoise/do_tts.py @@ -19,6 +19,8 @@ if __name__ == '__main__': parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3) parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True) + parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.' + 'Increasing this can in some cases reduce the likelyhood of multiple speakers. Defaults to 0 (disabled)', default=.0) args = parser.parse_args() os.makedirs(args.output_path, exist_ok=True) @@ -33,7 +35,7 @@ if __name__ == '__main__': voice_samples, conditioning_latents = load_voices(voice_sel) gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents, - preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True) + preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount) if isinstance(gen, list): for j, g in enumerate(gen): torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000) diff --git a/tortoise/models/cvvp.py b/tortoise/models/cvvp.py new file mode 100644 index 0000000..0cce0ae --- /dev/null +++ b/tortoise/models/cvvp.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from torch.utils.checkpoint import checkpoint + +from models.arch_util import AttentionBlock +from models.xtransformers import ContinuousTransformerWrapper, Encoder + + +def exists(val): + return val is not None + + +def masked_mean(t, mask): + t = t.masked_fill(~mask, 0.) + return t.sum(dim=1) / mask.sum(dim=1) + + +class CollapsingTransformer(nn.Module): + def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs): + super().__init__() + self.transformer = ContinuousTransformerWrapper( + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=model_dim, + depth=depth, + heads=heads, + ff_dropout=dropout, + ff_mult=1, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + **encoder_kwargs, + )) + self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1), + AttentionBlock( + output_dims, num_heads=heads, do_checkpoint=False), + nn.Conv1d(output_dims, output_dims, 1)) + self.mask_percentage = mask_percentage + + def forward(self, x, **transformer_kwargs): + h = self.transformer(x, **transformer_kwargs) + h = h.permute(0, 2, 1) + h = checkpoint(self.pre_combiner, h).permute(0, 2, 1) + if self.training: + mask = torch.rand_like(h.float()) > self.mask_percentage + else: + mask = torch.ones_like(h.float()).bool() + return masked_mean(h, mask) + + +class ConvFormatEmbedding(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.emb = nn.Embedding(*args, **kwargs) + + def forward(self, x): + y = self.emb(x) + return y.permute(0, 2, 1) + + +class CVVP(nn.Module): + def __init__( + self, + model_dim=512, + transformer_heads=8, + dropout=.1, + conditioning_enc_depth=8, + cond_mask_percentage=0, + mel_channels=80, + mel_codes=None, + speech_enc_depth=8, + speech_mask_percentage=0, + latent_multiplier=1, + ): + super().__init__() + latent_dim = latent_multiplier*model_dim + self.temperature = nn.Parameter(torch.tensor(1.)) + + self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), + nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) + self.conditioning_transformer = CollapsingTransformer( + model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) + self.to_conditioning_latent = nn.Linear( + latent_dim, latent_dim, bias=False) + + if mel_codes is None: + self.speech_emb = nn.Conv1d( + mel_channels, model_dim, kernel_size=5, padding=2) + else: + self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) + self.speech_transformer = CollapsingTransformer( + model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) + self.to_speech_latent = nn.Linear( + latent_dim, latent_dim, bias=False) + + def get_grad_norm_parameter_groups(self): + return { + 'conditioning': list(self.conditioning_transformer.parameters()), + 'speech': list(self.speech_transformer.parameters()), + } + + def forward( + self, + mel_cond, + mel_input, + return_loss=False + ): + cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1) + enc_cond = self.conditioning_transformer(cond_emb) + cond_latents = self.to_conditioning_latent(enc_cond) + + speech_emb = self.speech_emb(mel_input).permute(0, 2, 1) + enc_speech = self.speech_transformer(speech_emb) + speech_latents = self.to_speech_latent(enc_speech) + + cond_latents, speech_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (cond_latents, speech_latents)) + temp = self.temperature.exp() + + if not return_loss: + sim = einsum('n d, n d -> n', cond_latents, + speech_latents) * temp + return sim + + sim = einsum('i d, j d -> i j', cond_latents, + speech_latents) * temp + labels = torch.arange( + cond_latents.shape[0], device=mel_input.device) + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 + + return loss + + +if __name__ == '__main__': + clvp = CVVP() + clvp(torch.randn(2, 80, 100), + torch.randn(2, 80, 95), + return_loss=True) From b681fa9d11fcbeb6b41a92846c77cb72b4309020 Mon Sep 17 00:00:00 2001 From: Johan Nordberg Date: Wed, 25 May 2022 11:12:53 +0000 Subject: [PATCH 2/2] Skip CLVP if cvvp_amount is 1 Also fixes formatting bug in log message --- tortoise/api.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index 59b1604..f3b729f 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -427,17 +427,21 @@ class TextToSpeech: if self.cvvp is None: print("Computing best candidates using CLVP") else: - print(f"Computing best candidates using CLVP {int((1-cvvp_amount) * 100):02d}% and CVVP {int(cvvp_amount * 100):02d}%") + print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) - clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) + if cvvp_amount != 1: + clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) if auto_conds is not None and cvvp_amount > 0: cvvp_accumulator = 0 for cl in range(auto_conds.shape[1]): cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) cvvp = cvvp_accumulator / auto_conds.shape[1] - clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount)) + if cvvp_amount == 1: + clip_results.append(cvvp) + else: + clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount)) else: clip_results.append(clvp) clip_results = torch.cat(clip_results, dim=0)