misc fixes

James Betker 2022-05-02 18:00:57 -06:00
parent e00606a601
commit 5663e98904
6 changed files with 53 additions and 51 deletions

tortoise/api.py

@@ -37,6 +37,8 @@ def download_models(specific_models=None):
         'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/cvvp.pth',
         'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/diffusion_decoder.pth',
         'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/vocoder.pth',
+        'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/rlg_auto.pth',
+        'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/rlg_diffuser.pth',
     }
     os.makedirs('.models', exist_ok=True)
     def show_progress(block_num, block_size, total_size):
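The two new entries wire the random-latent-generator checkpoints into the same download path as the other models. As a minimal sketch of how one such entry might be fetched (the dict and the show_progress hook mirror the code above; the progress arithmetic is my own illustration of urlretrieve's reporthook convention):

    import os
    import urllib.request

    MODELS = {
        'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/rlg_auto.pth',
        'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/rlg_diffuser.pth',
    }

    def show_progress(block_num, block_size, total_size):
        # Matches urlretrieve's reporthook signature: called once per block.
        done = min(block_num * block_size, total_size)
        print(f'\rdownloaded {done}/{total_size} bytes', end='')

    os.makedirs('.models', exist_ok=True)
    for name, url in MODELS.items():
        if not os.path.exists(f'.models/{name}'):
            urllib.request.urlretrieve(url, f'.models/{name}', show_progress)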
@@ -110,9 +112,9 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
     stop_token_indices = (codes == stop_token).nonzero()
     if len(stop_token_indices) == 0:
         if complain:
-            print("No stop tokens found. This typically means the spoken audio is too long. In some cases, the output "
-                  "will still be good, though. Listen to it and if it is missing words, try breaking up your input "
-                  "text.")
+            print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
+                  "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
+                  "try breaking up your input text.")
         return codes
     else:
         codes[stop_token_indices] = 83
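To make the stop-token handling concrete, here is a toy run of the logic shown above; every token id except 83 is invented for illustration:

    import torch

    stop_token = 8192  # placeholder id; the real value comes from the caller
    codes = torch.tensor([12, 57, 8192, 8192])
    stop_token_indices = (codes == stop_token).nonzero()
    if len(stop_token_indices) == 0:
        print("No stop tokens found in one of the generated voice clips.")
    else:
        codes[stop_token_indices] = 83  # overwrite stop tokens, as above
    print(codes)  # tensor([12, 57, 83, 83])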
@@ -163,8 +165,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
-    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True,
-                 save_random_voices=False):
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -174,14 +175,11 @@ class TextToSpeech:
         :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
             (but are still rendered by the model). This can be used for prompt engineering.
             Default is true.
-        :param save_random_voices: When true, voices that are randomly generated are saved to the `random_voices`
-            directory. Default is false.
         """
         self.autoregressive_batch_size = autoregressive_batch_size
         self.enable_redaction = enable_redaction
         if self.enable_redaction:
            self.aligner = Wav2VecAlignment()
-        self.save_random_voices = save_random_voices
         self.tokenizer = VoiceBpeTokenizer()
         download_models()
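With save_random_voices gone, construction is down to three knobs. A quick sketch of the resulting API; the bracketed prompt in the comment is an invented example of the redaction behavior described in the docstring:

    from tortoise.api import TextToSpeech

    # enable_redaction is on by default: bracketed text such as
    # "[I am really sad,] Hello there." is rendered by the model
    # but cut from the returned audio.
    tts = TextToSpeech(autoregressive_batch_size=16, models_dir='.models',
                       enable_redaction=True)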
@@ -220,29 +218,6 @@ class TextToSpeech:
         self.rlg_auto = None
         self.rlg_diffusion = None

-    def tts_with_preset(self, text, preset='fast', **kwargs):
-        """
-        Calls TTS with one of a set of preset generation parameters. Options:
-            'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
-            'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
-            'standard': Very good quality. This is generally about as good as you are going to get.
-            'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
-        """
-        # Use generally found best tuning knobs for generation.
-        kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
-                       #'typical_sampling': True,
-                       'top_p': .8,
-                       'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
-        # Presets are defined here.
-        presets = {
-            'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 32, 'cond_free': False},
-            'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32},
-            'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128},
-            'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},
-        }
-        kwargs.update(presets[preset])
-        return self.tts(text, **kwargs)

     def get_conditioning_latents(self, voice_samples, return_mels=False):
         """
         Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
@@ -288,11 +263,30 @@ class TextToSpeech:
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
             self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
         with torch.no_grad():
-            latents = self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
-            if self.save_random_voices:
-                os.makedirs('random_voices', exist_ok=True)
-                torch.save(latents, f'random_voices/{str(uuid.uuid4())}.pth')
-            return latents
+            return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
+
+    def tts_with_preset(self, text, preset='fast', **kwargs):
+        """
+        Calls TTS with one of a set of preset generation parameters. Options:
+            'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
+            'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
+            'standard': Very good quality. This is generally about as good as you are going to get.
+            'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
+        """
+        # Use generally found best tuning knobs for generation.
+        kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
+                       #'typical_sampling': True,
+                       'top_p': .8,
+                       'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
+        # Presets are defined here.
+        presets = {
+            'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 32, 'cond_free': False},
+            'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32},
+            'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128},
+            'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},
+        }
+        kwargs.update(presets[preset])
+        return self.tts(text, **kwargs)

     def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
             # autoregressive generation parameters follow
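Moving tts_with_preset below the random-latent code doesn't change how it is called. A minimal usage sketch, assuming the load_voice helper imported in do_tts.py and mirroring how that script saves its 24 kHz output:

    import torchaudio
    from tortoise.api import TextToSpeech
    from tortoise.utils.audio import load_voice

    tts = TextToSpeech()
    # 'random' yields (None, None), which makes tts() draw random latents.
    voice_samples, conditioning_latents = load_voice('random')
    gen = tts.tts_with_preset("The expressiveness of autoregressive transformers is literally nuts!",
                              voice_samples=voice_samples,
                              conditioning_latents=conditioning_latents, preset='fast')
    torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)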
@@ -452,7 +446,7 @@ class TextToSpeech:
         def potentially_redact(clip, text):
             if self.enable_redaction:
-                return self.aligner.redact(clip, text)
+                return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
             return clip
         wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
         if len(wav_candidates) > 1:
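The fix suggests aligner.redact now works on audio without the channel dimension, with wav candidates presumably shaped [batch, channel, samples]. A toy shape check under that assumption:

    import torch

    wav_candidate = torch.zeros(1, 1, 24000)  # assumed [batch, channel, samples]
    flat = wav_candidate.squeeze(1)           # [1, 24000], what redact consumes
    restored = flat.unsqueeze(1)              # back to [1, 1, 24000]
    assert restored.shape == wav_candidate.shape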

tortoise/do_tts.py

@@ -8,7 +8,7 @@ from tortoise.utils.audio import load_audio, get_voices, load_voice
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
+    parser.add_argument('--text', type=str, help='Text to speak.', default="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.")
     parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
                                                   'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
     parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
@@ -21,7 +21,7 @@ if __name__ == '__main__':
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
-    tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True)
+    tts = TextToSpeech(models_dir=args.model_dir)
     selected_voices = args.voice.split(',')
     for k, voice in enumerate(selected_voices):
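Invocation is unchanged apart from the constructor; from the repo root this looks something like (script path assumed from the imports above):

    python tortoise/do_tts.py --text "The expressiveness of autoregressive transformers is literally nuts!" --voice random --preset fast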

tortoise/is_this_from_tortoise.py

@@ -5,7 +5,7 @@ from tortoise.utils.audio import load_audio
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3")
+    parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="../examples/favorite_riding_hood.mp3")
     args = parser.parse_args()
     clip = load_audio(args.clip, 24000)
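Note the new default path is relative to the script's own directory, so when running from the repo root you will likely want to pass --clip explicitly (script path assumed):

    python tortoise/is_this_from_tortoise.py --clip path/to/some_clip.wav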

tortoise/read.py

@@ -40,7 +40,7 @@ if __name__ == '__main__':
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
     args = parser.parse_args()
-    tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True)
+    tts = TextToSpeech(models_dir=args.model_dir)
     outpath = args.output_path
     selected_voices = args.voice.split(',')

tortoise/utils/audio.py

@@ -114,7 +114,7 @@ def load_voices(voices):
         if voice == 'random':
             print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
             return None, None
-        latent, clip = load_voice(voice)
+        clip, latent = load_voice(voice)
         if latent is None:
             assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
             clips.extend(clip)
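The swap matters because load_voice returns (clip, latent) in that order. A sketch of the contract as inferred from this fix and the surrounding branches ('tom' is just an example name from the voices/ directory):

    from tortoise.utils.audio import load_voice

    # For a voice backed by raw audio, latent is None and clip holds the wavs;
    # for a latent-only voice, clip is None instead.
    clip, latent = load_voice('tom')
    if latent is None:
        print('raw-audio voice:', type(clip))
    else:
        print('latent voice:', type(latent))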

tortoise/utils/wav2vec_alignment.py

@@ -1,3 +1,5 @@
+import re
+
 import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
@@ -11,7 +13,7 @@ class Wav2VecAlignment:
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
         self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')

-    def align(self, audio, expected_text, audio_sample_rate=24000, topk=3):
+    def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False):
         orig_len = audio.shape[-1]
         with torch.no_grad():
@@ -41,7 +43,9 @@ class Wav2VecAlignment:
         if len(expected_tokens) > 0:
             print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:"
-                  f" {self.tokenizer.decode(expected_tokens)}")
-            return None
+                  f" `{self.tokenizer.decode(expected_tokens)}`. Here's what wav2vec thought it heard:"
+                  f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`")
+            if not return_partial:
+                return None

         return alignments
@@ -54,6 +58,8 @@ class Wav2VecAlignment:
         for spl in splitted[1:]:
             assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
             fully_split.extend(spl.split(']'))
+        # Remove any non-alphabetic character in the input text. This makes matching more likely.
+        fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]
         # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
         non_redacted_intervals = []
         last_point = 0
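The new normalization step is easiest to see on an example; this mirrors the splitting logic above on an invented prompt:

    import re

    text = "[I am really sad,] Please feed me."
    splitted = text.split('[')
    fully_split = [splitted[0]]
    for spl in splitted[1:]:
        assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
        fully_split.extend(spl.split(']'))
    # The new step: strip anything non-alphabetic so the CTC targets are cleaner.
    fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]
    print(fully_split)  # ['', 'I am really sad', ' Please feed me']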
@@ -63,20 +69,22 @@ class Wav2VecAlignment:
             last_point += len(fully_split[i])
         bare_text = ''.join(fully_split)

-        alignments = self.align(audio, bare_text, audio_sample_rate, topk)
-        if alignments is None:
-            return audio  # Cannot redact because alignment did not succeed.
+        alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True)
+        # If alignment fails, we will attempt to recover by assuming the remaining alignments consume the rest of the string.
+        def get_alignment(i):
+            if i >= len(alignments):
+                return audio.shape[-1]
+            return alignments[i]

         output_audio = []
         for nri in non_redacted_intervals:
             start, stop = nri
-            output_audio.append(audio[:, alignments[start]:alignments[stop]])
+            output_audio.append(audio[:, get_alignment(start):get_alignment(stop)])
         return torch.cat(output_audio, dim=-1)
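The recovery path is simple to demonstrate in isolation; the alignment values below are invented sample offsets:

    alignments = [0, 2400, 4800]   # partial alignment: only 3 positions recovered
    num_samples = 24000            # stands in for audio.shape[-1]

    def get_alignment(i):
        if i >= len(alignments):
            return num_samples     # assume the tail consumes the rest of the clip
        return alignments[i]

    print(get_alignment(1))        # 2400
    print(get_alignment(7))        # 24000, clamped to the end of the audio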
 if __name__ == '__main__':
-    some_audio = load_audio('../../results/favorites/morgan_freeman_metallic_hydrogen.mp3', 24000)
+    some_audio = load_audio('../../results/train_dotrice_0.wav', 24000)
     aligner = Wav2VecAlignment()
-    text = "instead of molten iron, jupiter [and brown dwaves] have hydrogen, which [is under so much pressure that it] develops metallic properties"
+    text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them."
     redact = aligner.redact(some_audio, text)
     torchaudio.save(f'test_output.wav', redact, 24000)