diff --git a/tortoise/api.py b/tortoise/api.py index 11a13ff..65c7d6e 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -165,7 +165,7 @@ class TextToSpeech: Main entry point into Tortoise. """ - def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=False): + def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True): """ Constructor :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing @@ -275,7 +275,6 @@ class TextToSpeech: """ # Use generally found best tuning knobs for generation. kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0, - #'typical_sampling': True, 'top_p': .8, 'cond_free_k': 2.0, 'diffusion_temperature': 1.0}) # Presets are defined here. diff --git a/tortoise/utils/wav2vec_alignment.py b/tortoise/utils/wav2vec_alignment.py index d7d159c..fe4a3fb 100644 --- a/tortoise/utils/wav2vec_alignment.py +++ b/tortoise/utils/wav2vec_alignment.py @@ -7,13 +7,52 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTo from tortoise.utils.audio import load_audio +def max_alignment(s1, s2, skip_character='~', record={}): + """ + A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is + used to replace that character. + + Finally got to use my DP skills! + """ + assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}" + if len(s1) == 0: + return '' + if len(s2) == 0: + return skip_character * len(s1) + if s1 == s2: + return s1 + if s1[0] == s2[0]: + return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record) + + take_s1_key = (len(s1), len(s2) - 1) + if take_s1_key in record: + take_s1, take_s1_score = record[take_s1_key] + else: + take_s1 = max_alignment(s1, s2[1:], skip_character, record) + take_s1_score = len(take_s1.replace(skip_character, '')) + record[take_s1_key] = (take_s1, take_s1_score) + + take_s2_key = (len(s1) - 1, len(s2)) + if take_s2_key in record: + take_s2, take_s2_score = record[take_s2_key] + else: + take_s2 = max_alignment(s1[1:], s2, skip_character, record) + take_s2_score = len(take_s2.replace(skip_character, '')) + record[take_s2_key] = (take_s2, take_s2_score) + + return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2 + + class Wav2VecAlignment: + """ + Uses wav2vec2 to perform audio<->text alignment. + """ def __init__(self): self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h") self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols') - def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False): + def align(self, audio, expected_text, audio_sample_rate=24000): orig_len = audio.shape[-1] with torch.no_grad(): @@ -25,32 +64,59 @@ class Wav2VecAlignment: self.model = self.model.cpu() logits = logits[0] + pred_string = self.tokenizer.decode(logits.argmax(-1).tolist()) + + fixed_expectation = max_alignment(expected_text, pred_string) w2v_compression = orig_len // logits.shape[0] - expected_tokens = self.tokenizer.encode(expected_text) + expected_tokens = self.tokenizer.encode(fixed_expectation) + expected_chars = list(fixed_expectation) if len(expected_tokens) == 1: return [0] # The alignment is simple; there is only one token. expected_tokens.pop(0) # The first token is a given. - next_expected_token = expected_tokens.pop(0) + expected_chars.pop(0) + alignments = [0] + def pop_till_you_win(): + if len(expected_tokens) == 0: + return None + popped = expected_tokens.pop(0) + popped_char = expected_chars.pop(0) + while popped_char == '~': + alignments.append(-1) + if len(expected_tokens) == 0: + return None + popped = expected_tokens.pop(0) + popped_char = expected_chars.pop(0) + return popped + + next_expected_token = pop_till_you_win() for i, logit in enumerate(logits): - top = logit.topk(topk).indices.tolist() - if next_expected_token in top: + top = logit.argmax() + if next_expected_token == top: alignments.append(i * w2v_compression) if len(expected_tokens) > 0: - next_expected_token = expected_tokens.pop(0) + next_expected_token = pop_till_you_win() else: break - if len(expected_tokens) > 0: - print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:" - f" `{self.tokenizer.decode(expected_tokens)}`. Here's what wav2vec thought it heard:" - f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`") - if not return_partial: - return None + pop_till_you_win() + assert len(expected_tokens) == 0, "This shouldn't happen. My coding sucks." - return alignments + # Now fix up alignments. Anything with -1 should be interpolated. + alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable. + for i in range(len(alignments)): + if alignments[i] == -1: + for j in range(i+1, len(alignments)): + if alignments[j] != -1: + next_found_token = j + break + for j in range(i, next_found_token): + gap = alignments[next_found_token] - alignments[i-1] + alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1] - def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3): + return alignments[:-1] + + def redact(self, audio, expected_text, audio_sample_rate=24000): if '[' not in expected_text: return audio splitted = expected_text.split('[') @@ -58,33 +124,22 @@ class Wav2VecAlignment: for spl in splitted[1:]: assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.' fully_split.extend(spl.split(']')) - # Remove any non-alphabetic character in the input text. This makes matching more likely. - fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split] + # At this point, fully_split is a list of strings, with every other string being something that should be redacted. non_redacted_intervals = [] last_point = 0 for i in range(len(fully_split)): if i % 2 == 0: - non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1)) + end_interval = max(0, last_point + len(fully_split[i]) - 1) + non_redacted_intervals.append((last_point, end_interval)) last_point += len(fully_split[i]) bare_text = ''.join(fully_split) - alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True) - # If alignment fails, we will attempt to recover by assuming the remaining alignments consume the rest of the string. - def get_alignment(i): - if i >= len(alignments): - return audio.shape[-1] + alignments = self.align(audio, bare_text, audio_sample_rate) output_audio = [] for nri in non_redacted_intervals: start, stop = nri - output_audio.append(audio[:, get_alignment(start):get_alignment(stop)]) + output_audio.append(audio[:, alignments[start]:alignments[stop]]) return torch.cat(output_audio, dim=-1) - -if __name__ == '__main__': - some_audio = load_audio('../../results/train_dotrice_0.wav', 24000) - aligner = Wav2VecAlignment() - text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them." - redact = aligner.redact(some_audio, text) - torchaudio.save(f'test_output.wav', redact, 24000)