91 lines
4.1 KiB
Python
91 lines
4.1 KiB
Python
import re
|
|
|
|
import torch
|
|
import torchaudio
|
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
|
|
|
|
from tortoise.utils.audio import load_audio
|
|
|
|
|
|
class Wav2VecAlignment:
    """
    Aligns a transcript to an audio clip with a wav2vec2 CTC model and uses that alignment to cut bracketed
    ("redacted") passages out of the clip.
    """

    def __init__(self, device=None):
        """
        Load the pretrained acoustic model, feature extractor and tokenizer.

        :param device: Torch device used for inference. When None (the default), 'cuda' is used if CUDA is
                       available and 'cpu' otherwise.
        """
        # The model is parked on the CPU between calls; align() moves it to `self.device` only while it runs.
        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

    def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False):
        """
        Find, for each token of `expected_text`, the sample offset in `audio` at which it is first detected.

        Model output frames are scanned in order. The next expected token counts as found in the first frame whose
        `topk` highest-scoring tokens include it; each frame can match at most one token. The first token is always
        pinned to offset 0.

        :param audio: Waveform tensor sampled at `audio_sample_rate`; the last dimension is time.
                      NOTE(review): only the first item of the model output is used, so this presumably expects a
                      (1, samples) tensor - confirm against callers.
        :param expected_text: Transcript of the clip.
        :param audio_sample_rate: Sample rate of `audio`. The clip is resampled to 16kHz before it is fed to the model.
        :param topk: Number of highest-scoring tokens per frame that are accepted as a match.
        :param return_partial: When True, return the offsets found so far even if some tokens could not be located.
        :return: A list of sample offsets (in `audio`'s own sample rate), one per located token. An empty list when
                 the text encodes to no tokens. None when tokens are left unmatched and `return_partial` is False.
        """
        expected_tokens = self.tokenizer.encode(expected_text)
        if not expected_tokens:
            return []  # Nothing to align.
        if len(expected_tokens) == 1:
            return [0]  # The alignment is simple; there is only one token.

        orig_len = audio.shape[-1]

        with torch.no_grad():
            self.model = self.model.to(self.device)
            try:
                clip = torchaudio.functional.resample(audio.to(self.device), audio_sample_rate, 16000)
                # Normalize to zero mean / unit variance; the epsilon guards against a silent (zero-variance) clip.
                clip_norm = (clip - clip.mean()) / torch.sqrt(clip.var() + 1e-7)
                logits = self.model(clip_norm).logits[0].cpu()
            finally:
                # Always move the model back, even when inference fails, so it does not stay on the accelerator.
                self.model = self.model.cpu()

        # Number of input samples (at the original sample rate) represented by one output frame.
        w2v_compression = orig_len // logits.shape[0]

        # Top-k token ids for every frame, computed in a single call.
        frame_candidates = logits.topk(topk).indices.tolist()

        alignments = [0]  # The first token is a given.
        next_idx = 1  # Index into expected_tokens of the token we are currently looking for.
        for frame_idx, candidates in enumerate(frame_candidates):
            if expected_tokens[next_idx] in candidates:
                alignments.append(frame_idx * w2v_compression)
                next_idx += 1
                if next_idx == len(expected_tokens):
                    break

        # Everything from the token we were still searching for onwards was never located.
        unmatched = expected_tokens[next_idx:]
        if unmatched:
            print(f"Alignment did not work. {len(unmatched)} were not found, with the following string un-aligned:"
                  f" `{self.tokenizer.decode(unmatched)}`. Here's what wav2vec thought it heard:"
                  f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`")
            if not return_partial:
                return None

        return alignments

    def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
        """
        Remove from `audio` the portions that correspond to text enclosed in square brackets.

        :param audio: Waveform tensor indexed as [channel, time].
        :param expected_text: Transcript of `audio` in which every passage to remove is wrapped in "[...]". Brackets
                              must be paired and must not be nested.
        :param audio_sample_rate: Sample rate of `audio`, forwarded to align().
        :param topk: Forwarded to align().
        :return: `audio` itself when the text contains no "[", otherwise the kept spans concatenated along the
                 time axis (zero-length when nothing is kept).
        """
        if '[' not in expected_text:
            return audio

        head, *bracketed = expected_text.split('[')
        fully_split = [head]
        for spl in bracketed:
            assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
            # Split on the first "]" only, so each "[" contributes exactly one redacted and one kept piece and the
            # even/odd pattern below holds even if a stray "]" follows.
            redacted_piece, _, kept_piece = spl.partition(']')
            fully_split.extend((redacted_piece, kept_piece))
        # Remove any non-alphabetic character in the input text. This makes matching more likely.
        fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]

        # At this point, fully_split is a list of strings, with every other string being something that should be
        # redacted. Intervals are (index of first character, index of last character) within the joined text.
        # NOTE(review): these character indices are used to index the per-token list returned by align(), which
        # assumes the tokenizer emits one token per character - confirm.
        non_redacted_intervals = []
        last_point = 0
        for i, piece in enumerate(fully_split):
            # Empty kept pieces (e.g. text that starts with "[") carry no audio and would yield an inverted interval.
            if i % 2 == 0 and piece:
                non_redacted_intervals.append((last_point, last_point + len(piece) - 1))
            last_point += len(piece)

        bare_text = ''.join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True)

        # If alignment fails, we will attempt to recover by assuming the remaining alignments consume the rest of
        # the string.
        def get_alignment(i):
            if i >= len(alignments):
                return audio.shape[-1]
            return alignments[i]

        # NOTE(review): each slice ends where the last kept character begins, so that character's own audio is
        # excluded - confirm this is intended.
        output_audio = [audio[:, get_alignment(start):get_alignment(stop)] for start, stop in non_redacted_intervals]
        if not output_audio:
            # Everything was redacted; torch.cat() rejects an empty list, so return a zero-length clip instead.
            return audio[:, :0]
        return torch.cat(output_audio, dim=-1)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: strip the bracketed passage from a sample clip and write the result to disk.
    demo_sample_rate = 24000
    demo_text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them."

    source_clip = load_audio('../../results/train_dotrice_0.wav', demo_sample_rate)
    redacted_clip = Wav2VecAlignment().redact(source_clip, demo_text)
    torchaudio.save('test_output.wav', redacted_clip, demo_sample_rate)
|