From 8139afd0e5e3a3bd41675d6a212394d7cda1c94f Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 17 May 2022 12:21:25 -0600
Subject: [PATCH] Remove CVVP

After training a similar model for a different purpose, I realized that
this model is faulty: the contrastive loss it uses only pays attention
to high-frequency details which do not contribute meaningfully to
output quality. I validated this by comparing a no-CVVP output with
a baseline using tts-scores and found no differences.
---
 tortoise/api.py         |  36 +++--------
 tortoise/do_tts.py      |   6 +-
 tortoise/models/cvvp.py | 133 ----------------------------------------
 tortoise/read.py        |   6 +-
 4 files changed, 9 insertions(+), 172 deletions(-)
 delete mode 100644 tortoise/models/cvvp.py

diff --git a/tortoise/api.py b/tortoise/api.py
index 5abcb95..a724ae6 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -10,7 +10,6 @@ import progressbar
 import torchaudio
 
 from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
-from tortoise.models.cvvp import CVVP
 from tortoise.models.diffusion_decoder import DiffusionTts
 from tortoise.models.autoregressive import UnifiedVoice
 from tqdm import tqdm
@@ -35,7 +34,6 @@ def download_models(specific_models=None):
         'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
         'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
         'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
-        'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth',
         'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
         'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
         'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
@@ -223,10 +221,6 @@ class TextToSpeech:
                          use_xformers=True).cpu().eval()
         self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp2.pth'))
 
-        self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
-                         speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
-        self.cvvp.load_state_dict(torch.load(f'{models_dir}/cvvp.pth'))
-
         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
@@ -309,8 +303,6 @@ class TextToSpeech:
             return_deterministic_state=False,
             # autoregressive generation parameters follow
             num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
-            # CLVP & CVVP parameters
-            clvp_cvvp_slider=.5,
             # diffusion generation parameters follow
             diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
             **hf_generate_kwargs):
@@ -321,10 +313,10 @@ class TextToSpeech:
         :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
                                      can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
                                      Conditioning latents can be retrieved via get_conditioning_latents().
-        :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP and CVVP models) clips are returned.
+        :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned.
         :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
         ~~AUTOREGRESSIVE KNOBS~~
-        :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP+CVVP.
+        :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
                As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
         :param temperature: The softmax temperature of the autoregressive model.
         :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs.
@@ -336,11 +328,6 @@ class TextToSpeech:
                                  I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
                                  could use some tuning.
         :param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
-        ~~CLVP-CVVP KNOBS~~
-        :param clvp_cvvp_slider: Controls the influence of the CLVP and CVVP models in selecting the best output from the autoregressive model.
-                                [0,1]. Values closer to 1 will cause Tortoise to emit clips that follow the text more. Values closer to
-                                0 will cause Tortoise to emit clips that more closely follow the reference clip (e.g. the voice sounds more
-                                similar).
         ~~DIFFUSION KNOBS~~
         :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
                                      the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
@@ -402,28 +389,19 @@ class TextToSpeech:
                 samples.append(codes)
             self.autoregressive = self.autoregressive.cpu()
 
-            clip_results = []
+            clvp_results = []
             self.clvp = self.clvp.cuda()
-            self.cvvp = self.cvvp.cuda()
             if verbose:
-                print("Computing best candidates using CLVP and CVVP")
+                print("Computing best candidates using CLVP")
             for batch in tqdm(samples, disable=not verbose):
                 for i in range(batch.shape[0]):
                     batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
                 clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
-                if auto_conds is not None:
-                    cvvp_accumulator = 0
-                    for cl in range(auto_conds.shape[1]):
-                        cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
-                    cvvp = cvvp_accumulator / auto_conds.shape[1]
-                    clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider))
-                else:
-                    clip_results.append(clvp)
-            clip_results = torch.cat(clip_results, dim=0)
+                clvp_results.append(clvp)
+            clvp_results = torch.cat(clvp_results, dim=0)
             samples = torch.cat(samples, dim=0)
-            best_results = samples[torch.topk(clip_results, k=k).indices]
+            best_results = samples[torch.topk(clvp_results, k=k).indices]
             self.clvp = self.clvp.cpu()
-            self.cvvp = self.cvvp.cpu()
             del samples
 
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index eb5af04..0807c69 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -13,9 +13,6 @@ if __name__ == '__main__':
     parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
                                                  'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
     parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
-    parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
-                        help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
-                        default=.5)
     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
@@ -31,8 +28,7 @@ if __name__ == '__main__':
     for k, voice in enumerate(selected_voices):
         voice_samples, conditioning_latents = load_voice(voice)
         gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
-                                  preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider,
-                                  use_deterministic_seed=args.seed, return_deterministic_state=True)
+                                  preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True)
         if isinstance(gen, list):
             for j, g in enumerate(gen):
                 torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
diff --git a/tortoise/models/cvvp.py b/tortoise/models/cvvp.py
deleted file mode 100644
index d094649..0000000
--- a/tortoise/models/cvvp.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import einsum
-from torch.utils.checkpoint import checkpoint
-
-from tortoise.models.arch_util import AttentionBlock
-from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
-
-
-def exists(val):
-    return val is not None
-
-
-def masked_mean(t, mask):
-    t = t.masked_fill(~mask, 0.)
-    return t.sum(dim = 1) / mask.sum(dim = 1)
-
-
-class CollapsingTransformer(nn.Module):
-    def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
-        super().__init__()
-        self.transformer = ContinuousTransformerWrapper(
-            max_seq_len=-1,
-            use_pos_emb=False,
-            attn_layers=Encoder(
-                dim=model_dim,
-                depth=depth,
-                heads=heads,
-                ff_dropout=dropout,
-                ff_mult=1,
-                attn_dropout=dropout,
-                use_rmsnorm=True,
-                ff_glu=True,
-                rotary_pos_emb=True,
-                **encoder_kwargs,
-            ))
-        self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1),
-                                          AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
-                                          nn.Conv1d(output_dims, output_dims, 1))
-        self.mask_percentage = mask_percentage
-
-    def forward(self, x, **transformer_kwargs):
-        h = self.transformer(x, **transformer_kwargs)
-        h = h.permute(0,2,1)
-        h = checkpoint(self.pre_combiner, h).permute(0,2,1)
-        if self.training:
-            mask = torch.rand_like(h.float()) > self.mask_percentage
-        else:
-            mask = torch.ones_like(h.float()).bool()
-        return masked_mean(h, mask)
-
-
-class ConvFormatEmbedding(nn.Module):
-    def __init__(self, *args, **kwargs):
-        super().__init__()
-        self.emb = nn.Embedding(*args, **kwargs)
-
-    def forward(self, x):
-        y = self.emb(x)
-        return y.permute(0,2,1)
-
-
-class CVVP(nn.Module):
-    def __init__(
-            self,
-            model_dim=512,
-            transformer_heads=8,
-            dropout=.1,
-            conditioning_enc_depth=8,
-            cond_mask_percentage=0,
-            mel_channels=80,
-            mel_codes=None,
-            speech_enc_depth=8,
-            speech_mask_percentage=0,
-            latent_multiplier=1,
-    ):
-        super().__init__()
-        latent_dim = latent_multiplier*model_dim
-        self.temperature = nn.Parameter(torch.tensor(1.))
-
-        self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
-                                      nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
-        self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
-        self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
-
-        if mel_codes is None:
-            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
-        else:
-            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
-        self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
-
-    def get_grad_norm_parameter_groups(self):
-        return {
-            'conditioning': list(self.conditioning_transformer.parameters()),
-            'speech': list(self.speech_transformer.parameters()),
-        }
-
-    def forward(
-            self,
-            mel_cond,
-            mel_input,
-            return_loss=False
-    ):
-        cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
-        enc_cond = self.conditioning_transformer(cond_emb)
-        cond_latents = self.to_conditioning_latent(enc_cond)
-
-        speech_emb = self.speech_emb(mel_input).permute(0,2,1)
-        enc_speech = self.speech_transformer(speech_emb)
-        speech_latents = self.to_speech_latent(enc_speech)
-
-
-        cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
-        temp = self.temperature.exp()
-
-        if not return_loss:
-            sim = einsum('n d, n d -> n', cond_latents, speech_latents) * temp
-            return sim
-
-        sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp
-        labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
-        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
-
-        return loss
-
-
-if __name__ == '__main__':
-    clvp = CVVP()
-    clvp(torch.randn(2,80,100),
-         torch.randn(2,80,95),
-         return_loss=True)
\ No newline at end of file
diff --git a/tortoise/read.py b/tortoise/read.py
index ae68202..33332b6 100644
--- a/tortoise/read.py
+++ b/tortoise/read.py
@@ -18,9 +18,6 @@ if __name__ == '__main__':
     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
     parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
     parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
-    parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
-                        help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
-                        default=.5)
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
     parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
@@ -62,8 +59,7 @@ if __name__ == '__main__':
                 all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
                 continue
             gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
-                                      preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider,
-                                      use_deterministic_seed=seed)
+                                      preset=args.preset, use_deterministic_seed=seed)
             gen = gen.squeeze(0).cpu()
             torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
             all_parts.append(gen)