forked from mrq/tortoise-tts
Enable redaction by default
This commit is contained in:
parent
53cb3299d4
commit
b11f6ddd60
|
@ -165,7 +165,7 @@ class TextToSpeech:
|
||||||
Main entry point into Tortoise.
|
Main entry point into Tortoise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=False):
|
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -275,7 +275,6 @@ class TextToSpeech:
|
||||||
"""
|
"""
|
||||||
# Use generally found best tuning knobs for generation.
|
# Use generally found best tuning knobs for generation.
|
||||||
kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
|
kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
|
||||||
#'typical_sampling': True,
|
|
||||||
'top_p': .8,
|
'top_p': .8,
|
||||||
'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
|
'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
|
||||||
# Presets are defined here.
|
# Presets are defined here.
|
||||||
|
|
|
@ -7,13 +7,52 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTo
|
||||||
from tortoise.utils.audio import load_audio
|
from tortoise.utils.audio import load_audio
|
||||||
|
|
||||||
|
|
||||||
|
def max_alignment(s1, s2, skip_character='~', record={}):
|
||||||
|
"""
|
||||||
|
A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
|
||||||
|
used to replace that character.
|
||||||
|
|
||||||
|
Finally got to use my DP skills!
|
||||||
|
"""
|
||||||
|
assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
|
||||||
|
if len(s1) == 0:
|
||||||
|
return ''
|
||||||
|
if len(s2) == 0:
|
||||||
|
return skip_character * len(s1)
|
||||||
|
if s1 == s2:
|
||||||
|
return s1
|
||||||
|
if s1[0] == s2[0]:
|
||||||
|
return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
|
||||||
|
|
||||||
|
take_s1_key = (len(s1), len(s2) - 1)
|
||||||
|
if take_s1_key in record:
|
||||||
|
take_s1, take_s1_score = record[take_s1_key]
|
||||||
|
else:
|
||||||
|
take_s1 = max_alignment(s1, s2[1:], skip_character, record)
|
||||||
|
take_s1_score = len(take_s1.replace(skip_character, ''))
|
||||||
|
record[take_s1_key] = (take_s1, take_s1_score)
|
||||||
|
|
||||||
|
take_s2_key = (len(s1) - 1, len(s2))
|
||||||
|
if take_s2_key in record:
|
||||||
|
take_s2, take_s2_score = record[take_s2_key]
|
||||||
|
else:
|
||||||
|
take_s2 = max_alignment(s1[1:], s2, skip_character, record)
|
||||||
|
take_s2_score = len(take_s2.replace(skip_character, ''))
|
||||||
|
record[take_s2_key] = (take_s2, take_s2_score)
|
||||||
|
|
||||||
|
return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
|
||||||
|
|
||||||
|
|
||||||
class Wav2VecAlignment:
|
class Wav2VecAlignment:
|
||||||
|
"""
|
||||||
|
Uses wav2vec2 to perform audio<->text alignment.
|
||||||
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
||||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
||||||
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
|
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
|
||||||
|
|
||||||
def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False):
|
def align(self, audio, expected_text, audio_sample_rate=24000):
|
||||||
orig_len = audio.shape[-1]
|
orig_len = audio.shape[-1]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -25,32 +64,59 @@ class Wav2VecAlignment:
|
||||||
self.model = self.model.cpu()
|
self.model = self.model.cpu()
|
||||||
|
|
||||||
logits = logits[0]
|
logits = logits[0]
|
||||||
|
pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
|
||||||
|
|
||||||
|
fixed_expectation = max_alignment(expected_text, pred_string)
|
||||||
w2v_compression = orig_len // logits.shape[0]
|
w2v_compression = orig_len // logits.shape[0]
|
||||||
expected_tokens = self.tokenizer.encode(expected_text)
|
expected_tokens = self.tokenizer.encode(fixed_expectation)
|
||||||
|
expected_chars = list(fixed_expectation)
|
||||||
if len(expected_tokens) == 1:
|
if len(expected_tokens) == 1:
|
||||||
return [0] # The alignment is simple; there is only one token.
|
return [0] # The alignment is simple; there is only one token.
|
||||||
expected_tokens.pop(0) # The first token is a given.
|
expected_tokens.pop(0) # The first token is a given.
|
||||||
next_expected_token = expected_tokens.pop(0)
|
expected_chars.pop(0)
|
||||||
|
|
||||||
alignments = [0]
|
alignments = [0]
|
||||||
|
def pop_till_you_win():
|
||||||
|
if len(expected_tokens) == 0:
|
||||||
|
return None
|
||||||
|
popped = expected_tokens.pop(0)
|
||||||
|
popped_char = expected_chars.pop(0)
|
||||||
|
while popped_char == '~':
|
||||||
|
alignments.append(-1)
|
||||||
|
if len(expected_tokens) == 0:
|
||||||
|
return None
|
||||||
|
popped = expected_tokens.pop(0)
|
||||||
|
popped_char = expected_chars.pop(0)
|
||||||
|
return popped
|
||||||
|
|
||||||
|
next_expected_token = pop_till_you_win()
|
||||||
for i, logit in enumerate(logits):
|
for i, logit in enumerate(logits):
|
||||||
top = logit.topk(topk).indices.tolist()
|
top = logit.argmax()
|
||||||
if next_expected_token in top:
|
if next_expected_token == top:
|
||||||
alignments.append(i * w2v_compression)
|
alignments.append(i * w2v_compression)
|
||||||
if len(expected_tokens) > 0:
|
if len(expected_tokens) > 0:
|
||||||
next_expected_token = expected_tokens.pop(0)
|
next_expected_token = pop_till_you_win()
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
if len(expected_tokens) > 0:
|
pop_till_you_win()
|
||||||
print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:"
|
assert len(expected_tokens) == 0, "This shouldn't happen. My coding sucks."
|
||||||
f" `{self.tokenizer.decode(expected_tokens)}`. Here's what wav2vec thought it heard:"
|
|
||||||
f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`")
|
|
||||||
if not return_partial:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return alignments
|
# Now fix up alignments. Anything with -1 should be interpolated.
|
||||||
|
alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
|
||||||
|
for i in range(len(alignments)):
|
||||||
|
if alignments[i] == -1:
|
||||||
|
for j in range(i+1, len(alignments)):
|
||||||
|
if alignments[j] != -1:
|
||||||
|
next_found_token = j
|
||||||
|
break
|
||||||
|
for j in range(i, next_found_token):
|
||||||
|
gap = alignments[next_found_token] - alignments[i-1]
|
||||||
|
alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1]
|
||||||
|
|
||||||
def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
|
return alignments[:-1]
|
||||||
|
|
||||||
|
def redact(self, audio, expected_text, audio_sample_rate=24000):
|
||||||
if '[' not in expected_text:
|
if '[' not in expected_text:
|
||||||
return audio
|
return audio
|
||||||
splitted = expected_text.split('[')
|
splitted = expected_text.split('[')
|
||||||
|
@ -58,33 +124,22 @@ class Wav2VecAlignment:
|
||||||
for spl in splitted[1:]:
|
for spl in splitted[1:]:
|
||||||
assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
|
assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
|
||||||
fully_split.extend(spl.split(']'))
|
fully_split.extend(spl.split(']'))
|
||||||
# Remove any non-alphabetic character in the input text. This makes matching more likely.
|
|
||||||
fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]
|
|
||||||
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.
|
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.
|
||||||
non_redacted_intervals = []
|
non_redacted_intervals = []
|
||||||
last_point = 0
|
last_point = 0
|
||||||
for i in range(len(fully_split)):
|
for i in range(len(fully_split)):
|
||||||
if i % 2 == 0:
|
if i % 2 == 0:
|
||||||
non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
|
end_interval = max(0, last_point + len(fully_split[i]) - 1)
|
||||||
|
non_redacted_intervals.append((last_point, end_interval))
|
||||||
last_point += len(fully_split[i])
|
last_point += len(fully_split[i])
|
||||||
|
|
||||||
bare_text = ''.join(fully_split)
|
bare_text = ''.join(fully_split)
|
||||||
alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True)
|
alignments = self.align(audio, bare_text, audio_sample_rate)
|
||||||
# If alignment fails, we will attempt to recover by assuming the remaining alignments consume the rest of the string.
|
|
||||||
def get_alignment(i):
|
|
||||||
if i >= len(alignments):
|
|
||||||
return audio.shape[-1]
|
|
||||||
|
|
||||||
output_audio = []
|
output_audio = []
|
||||||
for nri in non_redacted_intervals:
|
for nri in non_redacted_intervals:
|
||||||
start, stop = nri
|
start, stop = nri
|
||||||
output_audio.append(audio[:, get_alignment(start):get_alignment(stop)])
|
output_audio.append(audio[:, alignments[start]:alignments[stop]])
|
||||||
return torch.cat(output_audio, dim=-1)
|
return torch.cat(output_audio, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
some_audio = load_audio('../../results/train_dotrice_0.wav', 24000)
|
|
||||||
aligner = Wav2VecAlignment()
|
|
||||||
text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them."
|
|
||||||
redact = aligner.redact(some_audio, text)
|
|
||||||
torchaudio.save(f'test_output.wav', redact, 24000)
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user