Add redaction support

This commit is contained in:
James Betker 2022-05-02 14:57:29 -06:00
parent f823e31e49
commit f631123264
2 changed files with 104 additions and 4 deletions

View File

@ -19,6 +19,7 @@ from tortoise.models.vocoder import UnivNetGenerator
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from tortoise.utils.tokenizer import VoiceBpeTokenizer
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
pbar = None
@ -158,11 +159,23 @@ def classify_audio_clip(clip):
class TextToSpeech:
"""
Main entry point into Tortoise.
"""
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
GPU OOM errors. Larger numbers generates slightly faster.
:param models_dir: Where model weights are stored. This should only be specified if you are providing your own
models, otherwise use the defaults.
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
(but are still rendered by the model). This can be used for prompt engineering.
"""
def __init__(self, autoregressive_batch_size=16, models_dir='.models'):
self.autoregressive_batch_size = autoregressive_batch_size
self.enable_redaction = enable_redaction
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
self.tokenizer = VoiceBpeTokenizer()
download_models()
@ -380,7 +393,6 @@ class TextToSpeech:
wav_candidates = []
self.diffusion = self.diffusion.cuda()
self.vocoder = self.vocoder.cuda()
diffusion_conds =
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
@ -403,6 +415,12 @@ class TextToSpeech:
self.diffusion = self.diffusion.cpu()
self.vocoder = self.vocoder.cpu()
def potentially_redact(self, clip, text):
if self.enable_redaction:
return self.aligner.redact(clip, text)
return clip
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
if len(wav_candidates) > 1:
return wav_candidates
return wav_candidates[0]

View File

@ -0,0 +1,82 @@
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from tortoise.utils.audio import load_audio
class Wav2VecAlignment:
def __init__(self):
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
def align(self, audio, expected_text, audio_sample_rate=24000, topk=3):
orig_len = audio.shape[-1]
with torch.no_grad():
self.model = self.model.cuda()
audio = audio.to('cuda')
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
logits = self.model(clip_norm).logits
self.model = self.model.cpu()
logits = logits[0]
w2v_compression = orig_len // logits.shape[0]
expected_tokens = self.tokenizer.encode(expected_text)
if len(expected_tokens) == 1:
return [0] # The alignment is simple; there is only one token.
expected_tokens.pop(0) # The first token is a given.
next_expected_token = expected_tokens.pop(0)
alignments = [0]
for i, logit in enumerate(logits):
top = logit.topk(topk).indices.tolist()
if next_expected_token in top:
alignments.append(i * w2v_compression)
if len(expected_tokens) > 0:
next_expected_token = expected_tokens.pop(0)
else:
break
if len(expected_tokens) > 0:
print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:"
f" {self.tokenizer.decode(expected_tokens)}")
return None
return alignments
def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
if '[' not in expected_text:
return audio
splitted = expected_text.split('[')
fully_split = [splitted[0]]
for spl in splitted[1:]:
assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
fully_split.extend(spl.split(']'))
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.
non_redacted_intervals = []
last_point = 0
for i in range(len(fully_split)):
if i % 2 == 0:
non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
last_point += len(fully_split[i])
bare_text = ''.join(fully_split)
alignments = self.align(audio, bare_text, audio_sample_rate, topk)
if alignments is None:
return audio # Cannot redact because alignment did not succeed.
output_audio = []
for nri in non_redacted_intervals:
start, stop = nri
output_audio.append(audio[:, alignments[start]:alignments[stop]])
return torch.cat(output_audio, dim=-1)
if __name__ == '__main__':
some_audio = load_audio('../../results/favorites/morgan_freeman_metallic_hydrogen.mp3', 24000)
aligner = Wav2VecAlignment()
text = "instead of molten iron, jupiter [and brown dwaves] have hydrogen, which [is under so much pressure that it] develops metallic properties"
redact = aligner.redact(some_audio, text)
torchaudio.save(f'test_output.wav', redact, 24000)