From 9007955d88798d28982321b7ed5e9dd003cb17a9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 2 May 2022 14:57:29 -0600
Subject: [PATCH] Add redaction support

---
 tortoise/api.py                     | 26 +++++++--
 tortoise/utils/wav2vec_alignment.py | 82 +++++++++++++++++++++++++++++
 2 files changed, 104 insertions(+), 4 deletions(-)
 create mode 100644 tortoise/utils/wav2vec_alignment.py

diff --git a/tortoise/api.py b/tortoise/api.py
index a50057c..6cb4733 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -19,6 +19,7 @@ from tortoise.models.vocoder import UnivNetGenerator
 from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
 from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
+from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 
 pbar = None
 
@@ -158,11 +159,23 @@ def classify_audio_clip(clip):
 class TextToSpeech:
     """
     Main entry point into Tortoise.
-    :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
-                                      GPU OOM errors. Larger numbers generates slightly faster.
     """
-    def __init__(self, autoregressive_batch_size=16, models_dir='.models'):
+
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
+        """
+        Constructor
+        :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
+                                          GPU OOM errors. Larger numbers generate slightly faster.
+        :param models_dir: Where model weights are stored. This should only be specified if you are providing your own
+                           models, otherwise use the defaults.
+        :param enable_redaction: When true, text enclosed in brackets is automatically redacted from the spoken output
+                                 (but is still rendered by the model). This can be used for prompt engineering.
+        """
         self.autoregressive_batch_size = autoregressive_batch_size
+        self.enable_redaction = enable_redaction
+        if self.enable_redaction:
+            self.aligner = Wav2VecAlignment()
+
         self.tokenizer = VoiceBpeTokenizer()
         download_models()
 
@@ -380,7 +393,6 @@ class TextToSpeech:
         wav_candidates = []
         self.diffusion = self.diffusion.cuda()
         self.vocoder = self.vocoder.cuda()
-        diffusion_conds =
         for b in range(best_results.shape[0]):
             codes = best_results[b].unsqueeze(0)
             latents = best_latents[b].unsqueeze(0)
@@ -403,6 +415,12 @@ class TextToSpeech:
         self.diffusion = self.diffusion.cpu()
         self.vocoder = self.vocoder.cpu()
 
+        def potentially_redact(clip, text):
+            if self.enable_redaction:
+                return self.aligner.redact(clip, text)
+            return clip
+        wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
         if len(wav_candidates) > 1:
             return wav_candidates
         return wav_candidates[0]
+
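A note on intended usage: with enable_redaction left at its default of True, bracketed spans in the prompt still condition the model's delivery, but are aligned against the generated speech and cut out of the returned waveform. A minimal sketch using only the API added by this patch (the clip path below is a hypothetical placeholder for a 24 kHz rendering of the full prompt, brackets included):

    from tortoise.utils.audio import load_audio
    from tortoise.utils.wav2vec_alignment import Wav2VecAlignment

    clip = load_audio('rendered_prompt.wav', 24000)  # hypothetical path
    aligner = Wav2VecAlignment()
    # The bracketed span was spoken by the model but is removed here; the
    # returned audio should contain only "please leave now."
    redacted = aligner.redact(clip, "[I am really angry,] please leave now.")

Inside TextToSpeech, the same call is applied to every candidate waveform through the potentially_redact helper above.
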
+ """ self.autoregressive_batch_size = autoregressive_batch_size + self.enable_redaction = enable_redaction + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + self.tokenizer = VoiceBpeTokenizer() download_models() @@ -380,7 +393,6 @@ class TextToSpeech: wav_candidates = [] self.diffusion = self.diffusion.cuda() self.vocoder = self.vocoder.cuda() - diffusion_conds = for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) latents = best_latents[b].unsqueeze(0) @@ -403,6 +415,12 @@ class TextToSpeech: self.diffusion = self.diffusion.cpu() self.vocoder = self.vocoder.cpu() + def potentially_redact(self, clip, text): + if self.enable_redaction: + return self.aligner.redact(clip, text) + return clip + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] if len(wav_candidates) > 1: return wav_candidates return wav_candidates[0] + diff --git a/tortoise/utils/wav2vec_alignment.py b/tortoise/utils/wav2vec_alignment.py new file mode 100644 index 0000000..748a2f5 --- /dev/null +++ b/tortoise/utils/wav2vec_alignment.py @@ -0,0 +1,82 @@ +import torch +import torchaudio +from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor + +from tortoise.utils.audio import load_audio + + +class Wav2VecAlignment: + def __init__(self): + self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h") + self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols') + + def align(self, audio, expected_text, audio_sample_rate=24000, topk=3): + orig_len = audio.shape[-1] + + with torch.no_grad(): + self.model = self.model.cuda() + audio = audio.to('cuda') + audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000) + clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + logits = self.model(clip_norm).logits + self.model = self.model.cpu() + + logits = logits[0] + w2v_compression = orig_len // logits.shape[0] + expected_tokens = self.tokenizer.encode(expected_text) + if len(expected_tokens) == 1: + return [0] # The alignment is simple; there is only one token. + expected_tokens.pop(0) # The first token is a given. + next_expected_token = expected_tokens.pop(0) + alignments = [0] + for i, logit in enumerate(logits): + top = logit.topk(topk).indices.tolist() + if next_expected_token in top: + alignments.append(i * w2v_compression) + if len(expected_tokens) > 0: + next_expected_token = expected_tokens.pop(0) + else: + break + + if len(expected_tokens) > 0: + print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:" + f" {self.tokenizer.decode(expected_tokens)}") + return None + + return alignments + + def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3): + if '[' not in expected_text: + return audio + splitted = expected_text.split('[') + fully_split = [splitted[0]] + for spl in splitted[1:]: + assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.' + fully_split.extend(spl.split(']')) + # At this point, fully_split is a list of strings, with every other string being something that should be redacted. 
+        non_redacted_intervals = []
+        last_point = 0
+        for i in range(len(fully_split)):
+            if i % 2 == 0:
+                non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
+            last_point += len(fully_split[i])
+
+        bare_text = ''.join(fully_split)
+        alignments = self.align(audio, bare_text, audio_sample_rate, topk)
+        if alignments is None:
+            return audio  # Cannot redact because alignment did not succeed.
+
+        output_audio = []
+        for nri in non_redacted_intervals:
+            start, stop = nri
+            output_audio.append(audio[:, alignments[start]:alignments[stop]])
+        return torch.cat(output_audio, dim=-1)
+
+
+if __name__ == '__main__':
+    some_audio = load_audio('../../results/favorites/morgan_freeman_metallic_hydrogen.mp3', 24000)
+    aligner = Wav2VecAlignment()
+    text = "instead of molten iron, jupiter [and brown dwarfs] have hydrogen, which [is under so much pressure that it] develops metallic properties"
+    redacted = aligner.redact(some_audio, text)
+    torchaudio.save('test_output.wav', redacted, 24000)
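For intuition about align(): it runs the audio through a CTC-trained wav2vec2 model, then greedily scans the per-frame logits, advancing to the next expected token whenever it appears in a frame's top-k predictions and recording that frame's offset converted back to sample space. A self-contained sketch of just that search, with toy logits standing in for the wav2vec2 output and topk=1 to keep the toy deterministic (align() defaults to 3):

    import torch

    # Toy stand-in for the CTC logits align() gets from wav2vec2:
    # 5 audio frames x 5 token classes, one dominant token per frame.
    logits = torch.eye(5)
    expected_tokens = [0, 1, 2, 3, 4]  # pretend output of tokenizer.encode()
    w2v_compression = 320              # plays the role of orig_len // logits.shape[0]

    expected_tokens.pop(0)             # first token assumed to start at sample 0
    next_expected = expected_tokens.pop(0)
    alignments = [0]
    for i, frame in enumerate(logits):
        if next_expected in frame.topk(1).indices.tolist():
            alignments.append(i * w2v_compression)
            if expected_tokens:
                next_expected = expected_tokens.pop(0)
            else:
                break
    print(alignments)  # [0, 320, 640, 960, 1280]: a sample offset per token

Assuming 'jbetker/tacotron_symbols' is a character-level symbol set (tacotron symbol vocabularies typically are), each entry lines up with one character of the input text, which is what lets redact() index alignments with character offsets.
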
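The interval bookkeeping in redact() is likewise easy to check on a concrete string. A runnable sketch of just the text-side logic (no audio or model involved), mirroring the splitting and interval loops above:

    expected_text = "hello [secret] world"
    splitted = expected_text.split('[')        # ['hello ', 'secret] world']
    fully_split = [splitted[0]]
    for spl in splitted[1:]:
        assert ']' in spl, 'Every "[" must be paired with a "]" with no nesting.'
        fully_split.extend(spl.split(']'))
    print(fully_split)                         # ['hello ', 'secret', ' world']

    # Even indices survive, odd indices are redacted. Character intervals are
    # computed against the bare (bracket-free) text 'hello secret world':
    non_redacted_intervals = []
    last_point = 0
    for i, piece in enumerate(fully_split):
        if i % 2 == 0:
            non_redacted_intervals.append((last_point, last_point + len(piece) - 1))
        last_point += len(piece)
    print(non_redacted_intervals)              # [(0, 5), (12, 17)]

redact() then maps these character offsets through the alignments returned by align() to obtain sample ranges, slices the audio down to the kept ranges, and concatenates the results.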