forked from mrq/tortoise-tts
Revive CVVP model
This commit is contained in:
parent
e0be49f02f
commit
0ca4d8f291
|
@ -16,6 +16,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from tortoise.models.arch_util import TorchMelSpectrogram
|
from tortoise.models.arch_util import TorchMelSpectrogram
|
||||||
from tortoise.models.clvp import CLVP
|
from tortoise.models.clvp import CLVP
|
||||||
|
from tortoise.models.cvvp import CVVP
|
||||||
from tortoise.models.random_latent_generator import RandomLatentConverter
|
from tortoise.models.random_latent_generator import RandomLatentConverter
|
||||||
from tortoise.models.vocoder import UnivNetGenerator
|
from tortoise.models.vocoder import UnivNetGenerator
|
||||||
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
|
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
|
||||||
|
@ -26,21 +27,23 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
|
||||||
pbar = None
|
pbar = None
|
||||||
|
|
||||||
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')
|
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')
|
||||||
|
|
||||||
def download_models(specific_models=None):
|
|
||||||
"""
|
|
||||||
Call to download all the models that Tortoise uses.
|
|
||||||
"""
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
|
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
|
||||||
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
|
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
|
||||||
'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
|
'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
|
||||||
|
'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth',
|
||||||
'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
|
'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
|
||||||
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
|
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
|
||||||
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
||||||
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def download_models(specific_models=None):
|
||||||
|
"""
|
||||||
|
Call to download all the models that Tortoise uses.
|
||||||
|
"""
|
||||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||||
|
|
||||||
def show_progress(block_num, block_size, total_size):
|
def show_progress(block_num, block_size, total_size):
|
||||||
global pbar
|
global pbar
|
||||||
if pbar is None:
|
if pbar is None:
|
||||||
|
@ -64,6 +67,18 @@ def download_models(specific_models=None):
|
||||||
print('Done.')
|
print('Done.')
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model_name, models_dir=MODELS_DIR):
|
||||||
|
"""
|
||||||
|
Get path to given model, download it if it doesn't exist.
|
||||||
|
"""
|
||||||
|
if model_name not in MODELS:
|
||||||
|
raise ValueError(f'Model {model_name} not found in available models.')
|
||||||
|
model_path = os.path.join(models_dir, model_name)
|
||||||
|
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
|
||||||
|
download_models([model_name])
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
def pad_or_truncate(t, length):
|
def pad_or_truncate(t, length):
|
||||||
"""
|
"""
|
||||||
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
|
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
|
||||||
|
@ -151,11 +166,10 @@ def classify_audio_clip(clip):
|
||||||
:param clip: torch tensor containing audio waveform data (get it from load_audio)
|
:param clip: torch tensor containing audio waveform data (get it from load_audio)
|
||||||
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
|
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
|
||||||
"""
|
"""
|
||||||
download_models(['classifier.pth'])
|
|
||||||
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
||||||
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
||||||
dropout=0, kernel_size=5, distribute_zero_label=False)
|
dropout=0, kernel_size=5, distribute_zero_label=False)
|
||||||
classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu')))
|
classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
|
||||||
clip = clip.cpu().unsqueeze(0)
|
clip = clip.cpu().unsqueeze(0)
|
||||||
results = F.softmax(classifier(clip), dim=-1)
|
results = F.softmax(classifier(clip), dim=-1)
|
||||||
return results[0][0]
|
return results[0][0]
|
||||||
|
@ -193,13 +207,13 @@ class TextToSpeech:
|
||||||
(but are still rendered by the model). This can be used for prompt engineering.
|
(but are still rendered by the model). This can be used for prompt engineering.
|
||||||
Default is true.
|
Default is true.
|
||||||
"""
|
"""
|
||||||
|
self.models_dir = models_dir
|
||||||
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
||||||
self.enable_redaction = enable_redaction
|
self.enable_redaction = enable_redaction
|
||||||
if self.enable_redaction:
|
if self.enable_redaction:
|
||||||
self.aligner = Wav2VecAlignment()
|
self.aligner = Wav2VecAlignment()
|
||||||
|
|
||||||
self.tokenizer = VoiceBpeTokenizer()
|
self.tokenizer = VoiceBpeTokenizer()
|
||||||
download_models()
|
|
||||||
|
|
||||||
if os.path.exists(f'{models_dir}/autoregressive.ptt'):
|
if os.path.exists(f'{models_dir}/autoregressive.ptt'):
|
||||||
# Assume this is a traced directory.
|
# Assume this is a traced directory.
|
||||||
|
@ -210,27 +224,34 @@ class TextToSpeech:
|
||||||
model_dim=1024,
|
model_dim=1024,
|
||||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||||
train_solo_embeddings=False).cpu().eval()
|
train_solo_embeddings=False).cpu().eval()
|
||||||
self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth'))
|
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
|
||||||
|
|
||||||
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
||||||
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
||||||
self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth'))
|
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir)))
|
||||||
|
|
||||||
self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
|
self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
|
||||||
text_seq_len=350, text_heads=12,
|
text_seq_len=350, text_heads=12,
|
||||||
num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
|
num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
|
||||||
use_xformers=True).cpu().eval()
|
use_xformers=True).cpu().eval()
|
||||||
self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp2.pth'))
|
self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
|
||||||
|
self.cvvp = None # CVVP model is only loaded if used.
|
||||||
|
|
||||||
self.vocoder = UnivNetGenerator().cpu()
|
self.vocoder = UnivNetGenerator().cpu()
|
||||||
self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
|
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g'])
|
||||||
self.vocoder.eval(inference=True)
|
self.vocoder.eval(inference=True)
|
||||||
|
|
||||||
# Random latent generators (RLGs) are loaded lazily.
|
# Random latent generators (RLGs) are loaded lazily.
|
||||||
self.rlg_auto = None
|
self.rlg_auto = None
|
||||||
self.rlg_diffusion = None
|
self.rlg_diffusion = None
|
||||||
|
|
||||||
|
def load_cvvp(self):
|
||||||
|
"""Load CVVP model."""
|
||||||
|
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
|
||||||
|
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
|
||||||
|
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
|
||||||
|
|
||||||
def get_conditioning_latents(self, voice_samples, return_mels=False):
|
def get_conditioning_latents(self, voice_samples, return_mels=False):
|
||||||
"""
|
"""
|
||||||
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
|
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
|
||||||
|
@ -273,9 +294,9 @@ class TextToSpeech:
|
||||||
# Lazy-load the RLG models.
|
# Lazy-load the RLG models.
|
||||||
if self.rlg_auto is None:
|
if self.rlg_auto is None:
|
||||||
self.rlg_auto = RandomLatentConverter(1024).eval()
|
self.rlg_auto = RandomLatentConverter(1024).eval()
|
||||||
self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu')))
|
self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
|
||||||
self.rlg_diffusion = RandomLatentConverter(2048).eval()
|
self.rlg_diffusion = RandomLatentConverter(2048).eval()
|
||||||
self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu')))
|
self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
||||||
|
|
||||||
|
@ -305,6 +326,8 @@ class TextToSpeech:
|
||||||
return_deterministic_state=False,
|
return_deterministic_state=False,
|
||||||
# autoregressive generation parameters follow
|
# autoregressive generation parameters follow
|
||||||
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
||||||
|
# CVVP parameters follow
|
||||||
|
cvvp_amount=.0,
|
||||||
# diffusion generation parameters follow
|
# diffusion generation parameters follow
|
||||||
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
|
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
|
||||||
**hf_generate_kwargs):
|
**hf_generate_kwargs):
|
||||||
|
@ -330,6 +353,9 @@ class TextToSpeech:
|
||||||
I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
|
I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
|
||||||
could use some tuning.
|
could use some tuning.
|
||||||
:param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
|
:param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
|
||||||
|
~~CLVP-CVVP KNOBS~~
|
||||||
|
:param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
|
||||||
|
[0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
|
||||||
~~DIFFUSION KNOBS~~
|
~~DIFFUSION KNOBS~~
|
||||||
:param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
|
:param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
|
||||||
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
|
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
|
||||||
|
@ -391,19 +417,35 @@ class TextToSpeech:
|
||||||
samples.append(codes)
|
samples.append(codes)
|
||||||
self.autoregressive = self.autoregressive.cpu()
|
self.autoregressive = self.autoregressive.cpu()
|
||||||
|
|
||||||
clvp_results = []
|
clip_results = []
|
||||||
self.clvp = self.clvp.cuda()
|
self.clvp = self.clvp.cuda()
|
||||||
|
if cvvp_amount > 0:
|
||||||
|
if self.cvvp is None:
|
||||||
|
self.load_cvvp()
|
||||||
|
self.cvvp = self.cvvp.cuda()
|
||||||
if verbose:
|
if verbose:
|
||||||
|
if self.cvvp is None:
|
||||||
print("Computing best candidates using CLVP")
|
print("Computing best candidates using CLVP")
|
||||||
|
else:
|
||||||
|
print(f"Computing best candidates using CLVP {int((1-cvvp_amount) * 100):02d}% and CVVP {int(cvvp_amount * 100):02d}%")
|
||||||
for batch in tqdm(samples, disable=not verbose):
|
for batch in tqdm(samples, disable=not verbose):
|
||||||
for i in range(batch.shape[0]):
|
for i in range(batch.shape[0]):
|
||||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||||
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
|
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
|
||||||
clvp_results.append(clvp)
|
if auto_conds is not None and cvvp_amount > 0:
|
||||||
clvp_results = torch.cat(clvp_results, dim=0)
|
cvvp_accumulator = 0
|
||||||
|
for cl in range(auto_conds.shape[1]):
|
||||||
|
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
||||||
|
cvvp = cvvp_accumulator / auto_conds.shape[1]
|
||||||
|
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
|
||||||
|
else:
|
||||||
|
clip_results.append(clvp)
|
||||||
|
clip_results = torch.cat(clip_results, dim=0)
|
||||||
samples = torch.cat(samples, dim=0)
|
samples = torch.cat(samples, dim=0)
|
||||||
best_results = samples[torch.topk(clvp_results, k=k).indices]
|
best_results = samples[torch.topk(clip_results, k=k).indices]
|
||||||
self.clvp = self.clvp.cpu()
|
self.clvp = self.clvp.cpu()
|
||||||
|
if self.cvvp is not None:
|
||||||
|
self.cvvp = self.cvvp.cpu()
|
||||||
del samples
|
del samples
|
||||||
|
|
||||||
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
|
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
|
||||||
|
|
|
@ -19,6 +19,8 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
||||||
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||||
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||||
|
parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
|
||||||
|
'Increasing this can in some cases reduce the likelyhood of multiple speakers. Defaults to 0 (disabled)', default=.0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
||||||
|
@ -33,7 +35,7 @@ if __name__ == '__main__':
|
||||||
voice_samples, conditioning_latents = load_voices(voice_sel)
|
voice_samples, conditioning_latents = load_voices(voice_sel)
|
||||||
|
|
||||||
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
||||||
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True)
|
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
|
||||||
if isinstance(gen, list):
|
if isinstance(gen, list):
|
||||||
for j, g in enumerate(gen):
|
for j, g in enumerate(gen):
|
||||||
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
|
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user