Revive CVVP model

This commit is contained in:
Johan Nordberg 2022-05-25 10:22:50 +00:00
parent e0be49f02f
commit 0ca4d8f291
2 changed files with 68 additions and 24 deletions

View File

@ -16,6 +16,7 @@ from tqdm import tqdm
from tortoise.models.arch_util import TorchMelSpectrogram from tortoise.models.arch_util import TorchMelSpectrogram
from tortoise.models.clvp import CLVP from tortoise.models.clvp import CLVP
from tortoise.models.cvvp import CVVP
from tortoise.models.random_latent_generator import RandomLatentConverter from tortoise.models.random_latent_generator import RandomLatentConverter
from tortoise.models.vocoder import UnivNetGenerator from tortoise.models.vocoder import UnivNetGenerator
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
@ -26,21 +27,23 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
pbar = None pbar = None
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models') MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')
MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth',
'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
}
def download_models(specific_models=None): def download_models(specific_models=None):
""" """
Call to download all the models that Tortoise uses. Call to download all the models that Tortoise uses.
""" """
MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
}
os.makedirs(MODELS_DIR, exist_ok=True) os.makedirs(MODELS_DIR, exist_ok=True)
def show_progress(block_num, block_size, total_size): def show_progress(block_num, block_size, total_size):
global pbar global pbar
if pbar is None: if pbar is None:
@ -64,6 +67,18 @@ def download_models(specific_models=None):
print('Done.') print('Done.')
def get_model_path(model_name, models_dir=MODELS_DIR):
"""
Get path to given model, download it if it doesn't exist.
"""
if model_name not in MODELS:
raise ValueError(f'Model {model_name} not found in available models.')
model_path = os.path.join(models_dir, model_name)
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
download_models([model_name])
return model_path
def pad_or_truncate(t, length): def pad_or_truncate(t, length):
""" """
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s. Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
@ -151,11 +166,10 @@ def classify_audio_clip(clip):
:param clip: torch tensor containing audio waveform data (get it from load_audio) :param clip: torch tensor containing audio waveform data (get it from load_audio)
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real. :return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
""" """
download_models(['classifier.pth'])
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
dropout=0, kernel_size=5, distribute_zero_label=False) dropout=0, kernel_size=5, distribute_zero_label=False)
classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu'))) classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
clip = clip.cpu().unsqueeze(0) clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1) results = F.softmax(classifier(clip), dim=-1)
return results[0][0] return results[0][0]
@ -193,13 +207,13 @@ class TextToSpeech:
(but are still rendered by the model). This can be used for prompt engineering. (but are still rendered by the model). This can be used for prompt engineering.
Default is true. Default is true.
""" """
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction self.enable_redaction = enable_redaction
if self.enable_redaction: if self.enable_redaction:
self.aligner = Wav2VecAlignment() self.aligner = Wav2VecAlignment()
self.tokenizer = VoiceBpeTokenizer() self.tokenizer = VoiceBpeTokenizer()
download_models()
if os.path.exists(f'{models_dir}/autoregressive.ptt'): if os.path.exists(f'{models_dir}/autoregressive.ptt'):
# Assume this is a traced directory. # Assume this is a traced directory.
@ -210,27 +224,34 @@ class TextToSpeech:
model_dim=1024, model_dim=1024,
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval() train_solo_embeddings=False).cpu().eval()
self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth')) self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval() layer_drop=0, unconditioned_percentage=0).cpu().eval()
self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth')) self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir)))
self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20, self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
text_seq_len=350, text_heads=12, text_seq_len=350, text_heads=12,
num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430, num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
use_xformers=True).cpu().eval() use_xformers=True).cpu().eval()
self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp2.pth')) self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
self.cvvp = None # CVVP model is only loaded if used.
self.vocoder = UnivNetGenerator().cpu() self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g']) self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g'])
self.vocoder.eval(inference=True) self.vocoder.eval(inference=True)
# Random latent generators (RLGs) are loaded lazily. # Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None self.rlg_auto = None
self.rlg_diffusion = None self.rlg_diffusion = None
def load_cvvp(self):
"""Load CVVP model."""
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
def get_conditioning_latents(self, voice_samples, return_mels=False): def get_conditioning_latents(self, voice_samples, return_mels=False):
""" """
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
@ -273,9 +294,9 @@ class TextToSpeech:
# Lazy-load the RLG models. # Lazy-load the RLG models.
if self.rlg_auto is None: if self.rlg_auto is None:
self.rlg_auto = RandomLatentConverter(1024).eval() self.rlg_auto = RandomLatentConverter(1024).eval()
self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu'))) self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
self.rlg_diffusion = RandomLatentConverter(2048).eval() self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu'))) self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
with torch.no_grad(): with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
@ -305,6 +326,8 @@ class TextToSpeech:
return_deterministic_state=False, return_deterministic_state=False,
# autoregressive generation parameters follow # autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
# CVVP parameters follow
cvvp_amount=.0,
# diffusion generation parameters follow # diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
**hf_generate_kwargs): **hf_generate_kwargs):
@ -330,6 +353,9 @@ class TextToSpeech:
I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
could use some tuning. could use some tuning.
:param typical_mass: The typical_mass parameter from the typical_sampling algorithm. :param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
~~CLVP-CVVP KNOBS~~
:param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
[0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
~~DIFFUSION KNOBS~~ ~~DIFFUSION KNOBS~~
:param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
@ -391,19 +417,35 @@ class TextToSpeech:
samples.append(codes) samples.append(codes)
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()
clvp_results = [] clip_results = []
self.clvp = self.clvp.cuda() self.clvp = self.clvp.cuda()
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.cuda()
if verbose: if verbose:
print("Computing best candidates using CLVP") if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
print(f"Computing best candidates using CLVP {int((1-cvvp_amount) * 100):02d}% and CVVP {int(cvvp_amount * 100):02d}%")
for batch in tqdm(samples, disable=not verbose): for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]): for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
clvp_results.append(clvp) if auto_conds is not None and cvvp_amount > 0:
clvp_results = torch.cat(clvp_results, dim=0) cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
else:
clip_results.append(clvp)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0) samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clvp_results, k=k).indices] best_results = samples[torch.topk(clip_results, k=k).indices]
self.clvp = self.clvp.cpu() self.clvp = self.clvp.cpu()
if self.cvvp is not None:
self.cvvp = self.cvvp.cpu()
del samples del samples
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning

View File

@ -19,6 +19,8 @@ if __name__ == '__main__':
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3) parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True) parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
'Increasing this can in some cases reduce the likelyhood of multiple speakers. Defaults to 0 (disabled)', default=.0)
args = parser.parse_args() args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True) os.makedirs(args.output_path, exist_ok=True)
@ -33,7 +35,7 @@ if __name__ == '__main__':
voice_samples, conditioning_latents = load_voices(voice_sel) voice_samples, conditioning_latents = load_voices(voice_sel)
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents, gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True) preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
if isinstance(gen, list): if isinstance(gen, list):
for j, g in enumerate(gen): for j, g in enumerate(gen):
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000) torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)