Merge pull request #64 from jnordberg/revive-cvvp

Revive CVVP model
This commit is contained in:
James Betker 2022-05-25 15:59:09 -06:00 committed by GitHub
commit 7becd30c2a
2 changed files with 73 additions and 25 deletions

View File

@ -16,6 +16,7 @@ from tqdm import tqdm
from tortoise.models.arch_util import TorchMelSpectrogram
from tortoise.models.clvp import CLVP
from tortoise.models.cvvp import CVVP
from tortoise.models.random_latent_generator import RandomLatentConverter
from tortoise.models.vocoder import UnivNetGenerator
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
@ -26,21 +27,23 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
pbar = None
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')
MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth',
'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
}
def download_models(specific_models=None):
"""
Call to download all the models that Tortoise uses.
"""
MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
}
os.makedirs(MODELS_DIR, exist_ok=True)
def show_progress(block_num, block_size, total_size):
global pbar
if pbar is None:
@ -64,6 +67,18 @@ def download_models(specific_models=None):
print('Done.')
def get_model_path(model_name, models_dir=MODELS_DIR):
"""
Get path to given model, download it if it doesn't exist.
"""
if model_name not in MODELS:
raise ValueError(f'Model {model_name} not found in available models.')
model_path = os.path.join(models_dir, model_name)
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
download_models([model_name])
return model_path
def pad_or_truncate(t, length):
"""
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
@ -151,11 +166,10 @@ def classify_audio_clip(clip):
:param clip: torch tensor containing audio waveform data (get it from load_audio)
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
"""
download_models(['classifier.pth'])
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
dropout=0, kernel_size=5, distribute_zero_label=False)
classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu')))
classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
@ -193,13 +207,13 @@ class TextToSpeech:
(but are still rendered by the model). This can be used for prompt engineering.
Default is true.
"""
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
self.tokenizer = VoiceBpeTokenizer()
download_models()
if os.path.exists(f'{models_dir}/autoregressive.ptt'):
# Assume this is a traced directory.
@ -210,27 +224,34 @@ class TextToSpeech:
model_dim=1024,
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth'))
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval()
self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth'))
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir)))
self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
text_seq_len=350, text_heads=12,
num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
use_xformers=True).cpu().eval()
self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp2.pth'))
self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
self.cvvp = None # CVVP model is only loaded if used.
self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g'])
self.vocoder.eval(inference=True)
# Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None
self.rlg_diffusion = None
def load_cvvp(self):
"""Load CVVP model."""
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
def get_conditioning_latents(self, voice_samples, return_mels=False):
"""
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
@ -273,9 +294,9 @@ class TextToSpeech:
# Lazy-load the RLG models.
if self.rlg_auto is None:
self.rlg_auto = RandomLatentConverter(1024).eval()
self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu')))
self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu')))
self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
@ -305,6 +326,8 @@ class TextToSpeech:
return_deterministic_state=False,
# autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
# CVVP parameters follow
cvvp_amount=.0,
# diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
**hf_generate_kwargs):
@ -330,6 +353,9 @@ class TextToSpeech:
I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
could use some tuning.
:param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
~~CLVP-CVVP KNOBS~~
:param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
[0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
~~DIFFUSION KNOBS~~
:param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
@ -391,19 +417,39 @@ class TextToSpeech:
samples.append(codes)
self.autoregressive = self.autoregressive.cpu()
clvp_results = []
clip_results = []
self.clvp = self.clvp.cuda()
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.cuda()
if verbose:
print("Computing best candidates using CLVP")
if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
clvp_results.append(clvp)
clvp_results = torch.cat(clvp_results, dim=0)
if cvvp_amount != 1:
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
else:
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
else:
clip_results.append(clvp)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clvp_results, k=k).indices]
best_results = samples[torch.topk(clip_results, k=k).indices]
self.clvp = self.clvp.cpu()
if self.cvvp is not None:
self.cvvp = self.cvvp.cpu()
del samples
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning

View File

@ -19,6 +19,8 @@ if __name__ == '__main__':
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
'Increasing this can in some cases reduce the likelyhood of multiple speakers. Defaults to 0 (disabled)', default=.0)
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
@ -33,7 +35,7 @@ if __name__ == '__main__':
voice_samples, conditioning_latents = load_voices(voice_sel)
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True)
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
if isinstance(gen, list):
for j, g in enumerate(gen):
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)