Add support for extracting and feeding conditioning latents directly into the model

- Adds a new script and API endpoints for doing this
- Reworks autoregressive and diffusion models so that the conditioning is computed separately (which will actually provide a mild performance boost)
- Updates README

This is untested. Need to do the following manual tests (and someday write unit tests for this behemoth before
it becomes a problem..)
1) Does get_conditioning_latents.py work?
2) Can I feed those latents back into the model by creating a new voice?
3) Can I still mix and match voices (both with conditioning latents and normal voices) with read.py?
This commit is contained in:
James Betker 2022-05-01 17:25:18 -06:00
parent a8264f5cef
commit 0ffc191408
8 changed files with 165 additions and 78 deletions

View File

@ -118,12 +118,24 @@ These settings are not available in the normal scripts packaged with Tortoise. T
### Playing with the voice latent
Tortoise ingests reference clips by feeding them through individually through a small submodel that produces a point latent, then taking the mean of all of the produced latents. The experimentation I have done has indicated that these point latents are quite expressive, affecting
everything from tone to speaking rate to speech abnormalities.
Tortoise ingests reference clips by feeding them through individually through a small submodel that produces a point latent,
then taking the mean of all of the produced latents. The experimentation I have done has indicated that these point latents
are quite expressive, affecting everything from tone to speaking rate to speech abnormalities.
This lends itself to some neat tricks. For example, you can combine feed two different voices to tortoise and it will output what it thinks the "average" of those two voices sounds like. You could also theoretically build a small extension to Tortoise that gradually shifts the
latent from one speaker to another, then apply it across a bit of spoken text (something I havent implemented yet, but might
get to soon!) I am sure there are other interesting things that can be done here. Please let me know what you find!
This lends itself to some neat tricks. For example, you can combine feed two different voices to tortoise and it will output
what it thinks the "average" of those two voices sounds like.
#### Generating conditioning latents from voices
Use the script `get_conditioning_latents.py` to extract conditioning latents for a voice you have installed. This script
will dump the latents to a .pth pickle file. The file will contain a single tuple, (autoregressive_latent, diffusion_latent).
Alternatively, use the api.TextToSpeech.get_conditioning_latents() to fetch the latents.
#### Using raw conditioning latents to generate speech
After you've played with them, you can use them to generate speech by creating a subdirectory in voices/ with a single
".pth" file containing the pickled conditioning latents as a tuple (autoregressive_latent, diffusion_latent).
### Send me feedback!

View File

@ -121,23 +121,14 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
return codes
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_samples, temperature=1, verbose=True):
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True):
"""
Uses the specified diffusion model to convert discrete codes into a spectrogram.
"""
with torch.no_grad():
cond_mels = []
for sample in conditioning_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to(latents.device), do_normalization=False)
cond_mels.append(cond_mel)
cond_mels = torch.stack(cond_mels, dim=1)
output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(latents, cond_mels, output_seq_len, False)
precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
@ -204,7 +195,7 @@ class TextToSpeech:
self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
self.vocoder.eval(inference=True)
def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs):
def tts_with_preset(self, text, preset='fast', **kwargs):
"""
Calls TTS with one of a set of preset generation parameters. Options:
'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
@ -225,9 +216,43 @@ class TextToSpeech:
'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},
}
kwargs.update(presets[preset])
return self.tts(text, voice_samples, **kwargs)
return self.tts(text, **kwargs)
def tts(self, text, voice_samples, k=1, verbose=True,
def get_conditioning_latents(self, voice_samples):
"""
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
properties.
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
"""
voice_samples = [v.to('cuda') for v in voice_samples]
auto_conds = []
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
auto_conds.append(format_conditioning(vs))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.cuda()
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()
diffusion_conds = []
for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to(voice_samples.device), do_normalization=False)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = self.diffusion.cuda()
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu()
return auto_latent, diffusion_latent
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
# autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
typical_sampling=False, typical_mass=.9,
@ -240,6 +265,9 @@ class TextToSpeech:
Produces an audio clip of the given text being spoken with the given reference voice.
:param text: Text to be spoken.
:param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data.
:param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
Conditioning latents can be retrieved via get_conditioning_latents().
:param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP and CVVP models) clips are returned.
:param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
~~AUTOREGRESSIVE KNOBS~~
@ -283,12 +311,10 @@ class TextToSpeech:
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
text = F.pad(text, (0, 1)) # This may not be necessary.
conds = []
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
conds.append(format_conditioning(vs))
conds = torch.stack(conds, dim=1)
if voice_samples is not None:
auto_conditioning, diffusion_conditioning = self.get_conditioning_latents(voice_samples)
else:
auto_conditioning, diffusion_conditioning = conditioning_latents
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
@ -301,7 +327,7 @@ class TextToSpeech:
if verbose:
print("Generating autoregressive samples..")
for b in tqdm(range(num_batches), disable=not verbose):
codes = self.autoregressive.inference_speech(conds, text,
codes = self.autoregressive.inference_speech(auto_conditioning, text,
do_sample=True,
top_p=top_p,
temperature=temperature,
@ -340,16 +366,18 @@ class TextToSpeech:
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
self.autoregressive = self.autoregressive.cuda()
best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
best_latents = self.autoregressive(auto_conditioning, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device),
return_latent=True, clip_inputs=False)
self.autoregressive = self.autoregressive.cpu()
del auto_conditioning
if verbose:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
self.diffusion = self.diffusion.cuda()
self.vocoder = self.vocoder.cuda()
diffusion_conds =
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
@ -365,7 +393,8 @@ class TextToSpeech:
latents = latents[:, :k]
break
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature, verbose=verbose)
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
temperature=diffusion_temperature, verbose=verbose)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav.cpu())
self.diffusion = self.diffusion.cpu()

View File

@ -4,7 +4,7 @@ import os
import torchaudio
from api import TextToSpeech
from tortoise.utils.audio import load_audio, get_voices
from tortoise.utils.audio import load_audio, get_voices, load_voice
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@ -21,14 +21,10 @@ if __name__ == '__main__':
tts = TextToSpeech()
voices = get_voices()
selected_voices = args.voice.split(',')
for voice in selected_voices:
cond_paths = voices[voice]
conds = []
for cond_path in cond_paths:
c = load_audio(cond_path, 22050)
conds.append(c)
gen = tts.tts_with_preset(args.text, conds, preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
voice_samples, conditioning_latents = load_voice(voice)
gen = tts.tts_with_preset(args.text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), gen.squeeze(0).cpu(), 24000)

View File

@ -0,0 +1,30 @@
import argparse
import os
import torch
from api import TextToSpeech
from tortoise.utils.audio import load_audio, get_voices
"""
Dumps the conditioning latents for the specified voice to disk. These are expressive latents which can be used for
other ML models, or can be augmented manually and fed back into Tortoise to affect vocal qualities.
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--voice', type=str, help='Selects the voice to convert to conditioning latents', default='pat')
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/conditioning_latents')
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
tts = TextToSpeech()
voices = get_voices()
selected_voices = args.voice.split(',')
for voice in selected_voices:
cond_paths = voices[voice]
conds = []
for cond_path in cond_paths:
c = load_audio(cond_path, 22050)
conds.append(c)
conditioning_latents = tts.get_conditioning_latents(conds)
torch.save(conditioning_latents, os.path.join(args.output_path, f'{voice}.pth'))

View File

@ -390,6 +390,17 @@ class UnifiedVoice(nn.Module):
else:
return first_logits
def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False, clip_inputs=True):
"""
@ -424,14 +435,7 @@ class UnifiedVoice(nn.Module):
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
conds = self.get_conditioning(speech_conditioning_input)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
@ -516,7 +520,7 @@ class UnifiedVoice(nn.Module):
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_mel.mean()
def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1,
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if not hasattr(self, 'inference_model'):
@ -536,14 +540,7 @@ class UnifiedVoice(nn.Module):
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
conds = speech_conditioning_latent
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)

View File

@ -219,18 +219,21 @@ class DiffusionTts(nn.Module):
}
return groups
def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
# Shuffle aligned_latent to BxCxS format
if is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
# Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
def get_conditioning(self, conditioning_input):
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
return conds
def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
# Shuffle aligned_latent to BxCxS format
if is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
conds = conditioning_latent
cond_emb = conds.mean(dim=-1)
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
if is_latent(aligned_conditioning):
@ -257,19 +260,19 @@ class DiffusionTts(nn.Module):
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, mel_pred
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
:param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_latent is not None)
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
unused_params = []
@ -281,7 +284,7 @@ class DiffusionTts(nn.Module):
if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings
else:
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1], True)
if is_latent(aligned_conditioning):
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
else:

View File

@ -5,7 +5,7 @@ import torch
import torchaudio
from api import TextToSpeech
from tortoise.utils.audio import load_audio, get_voices
from tortoise.utils.audio import load_audio, get_voices, load_voices
def split_and_recombine_text(texts, desired_length=200, max_len=300):
@ -40,7 +40,6 @@ if __name__ == '__main__':
args = parser.parse_args()
outpath = args.output_path
voices = get_voices()
selected_voices = args.voice.split(',')
regenerate = args.regenerate
if regenerate is not None:
@ -58,25 +57,15 @@ if __name__ == '__main__':
voice_sel = selected_voice.split('&')
else:
voice_sel = [selected_voice]
cond_paths = []
for vsel in voice_sel:
if vsel not in voices.keys():
print(f'Error: voice {vsel} not available. Skipping.')
continue
cond_paths.extend(voices[vsel])
if not cond_paths:
print('Error: no valid voices specified. Try again.')
conds = []
for cond_path in cond_paths:
c = load_audio(cond_path, 22050)
conds.append(c)
voice_samples, conditioning_latents = load_voices(voice_sel)
all_parts = []
for j, text in enumerate(texts):
if regenerate is not None and j not in regenerate:
all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
continue
gen = tts.tts_with_preset(text, conds, preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
gen = gen.squeeze(0).cpu()
torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
all_parts.append(gen)

View File

@ -91,6 +91,37 @@ def get_voices():
return voices
def load_voice(voice):
voices = get_voices()
paths = voices[voice]
if len(paths) == 1 and paths[0].endswith('.pth'):
return None, torch.load(paths[0])
else:
conds = []
for cond_path in paths:
c = load_audio(cond_path, 22050)
conds.append(c)
return conds, None
def load_voices(voices):
latents = []
clips = []
for voice in voices:
latent, clip = load_voice(voice)
if latent is None:
assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
clips.extend(clip)
elif voice is None:
assert len(voices) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
latents.append(latent)
if len(latents) == 0:
return clips
else:
latents = torch.stack(latents, dim=0)
return latents.mean(dim=0)
class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,