forked from mrq/tortoise-tts
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
9eac62598a
|
@ -5,25 +5,8 @@ import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
from api import TextToSpeech
|
from api import TextToSpeech
|
||||||
from tortoise.utils.audio import load_audio, get_voices, load_voices
|
from utils.audio import load_audio, get_voices, load_voices
|
||||||
|
from utils.text import split_and_recombine_text
|
||||||
|
|
||||||
def split_and_recombine_text(texts, desired_length=200, max_len=300):
|
|
||||||
# TODO: also split across '!' and '?'. Attempt to keep quotations together.
|
|
||||||
texts = [s.strip() + "." for s in texts.split('.')]
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
while i < len(texts):
|
|
||||||
ltxt = texts[i]
|
|
||||||
if len(ltxt) >= desired_length or i == len(texts)-1:
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
if len(ltxt) + len(texts[i+1]) > max_len:
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
texts[i] = f'{ltxt} {texts[i+1]}'
|
|
||||||
texts.pop(i+1)
|
|
||||||
return texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -119,14 +119,16 @@ def load_voices(voices):
|
||||||
if latent is None:
|
if latent is None:
|
||||||
assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
|
assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
|
||||||
clips.extend(clip)
|
clips.extend(clip)
|
||||||
elif voice is None:
|
elif clip is None:
|
||||||
assert len(voices) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
|
assert len(clips) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
|
||||||
latents.append(latent)
|
latents.append(latent)
|
||||||
if len(latents) == 0:
|
if len(latents) == 0:
|
||||||
return clips, None
|
return clips, None
|
||||||
else:
|
else:
|
||||||
latents = torch.stack(latents, dim=0)
|
latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0)
|
||||||
return None, latents.mean(dim=0)
|
latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0)
|
||||||
|
latents = (latents_0,latents_1)
|
||||||
|
return None, latents
|
||||||
|
|
||||||
|
|
||||||
class TacotronSTFT(torch.nn.Module):
|
class TacotronSTFT(torch.nn.Module):
|
||||||
|
@ -178,4 +180,4 @@ def wav_to_univnet_mel(wav, do_normalization=False):
|
||||||
mel = stft.mel_spectrogram(wav)
|
mel = stft.mel_spectrogram(wav)
|
||||||
if do_normalization:
|
if do_normalization:
|
||||||
mel = normalize_tacotron_mel(mel)
|
mel = normalize_tacotron_mel(mel)
|
||||||
return mel
|
return mel
|
||||||
|
|
84
tortoise/utils/text.py
Normal file
84
tortoise/utils/text.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def split_and_recombine_text(text, desired_length=200, max_length=300):
    """Split text into chunks of at most ``max_length`` characters, aiming for
    ``desired_length``, while trying to keep sentences intact.

    Sentence boundaries are '!', '?', newline, and '.' followed by whitespace
    (or end of text); boundaries inside double quotes are ignored so
    quotations stay together.

    Args:
        text: Input text to split.
        desired_length: Preferred chunk size in characters; a chunk is emitted
            once it reaches this length at a sentence boundary.
        max_length: Hard upper bound; reaching it forces a split, preferably
            at the last seen sentence boundary, otherwise at a word boundary.

    Returns:
        List of non-empty, stripped text chunks.
    """
    # Normalize: collapse blank lines and whitespace runs, and convert curly
    # quotes to ascii so the quote tracking below only has to watch for '"'.
    text = re.sub(r'\n\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[“”]', '"', text)

    rv = []              # finished chunks
    in_quote = False     # currently inside a double-quoted span?
    current = ""         # chunk being accumulated
    split_pos = []       # positions of sentence boundaries seen in `current`
    pos = -1             # cursor into `text`; first seek(1) lands on index 0

    def seek(delta):
        """Move the cursor by `delta` characters (negative moves backwards),
        keeping `in_quote` consistent, and return (char at cursor, next char
        or "" at end of text)."""
        nonlocal pos, in_quote, text
        is_neg = delta < 0
        for _ in range(abs(delta)):
            if is_neg:
                pos -= 1
            else:
                pos += 1
            # Toggling on every pass works in both directions: stepping back
            # over a quote undoes the toggle made stepping forward over it.
            if text[pos] == '"':
                in_quote = not in_quote
        return text[pos], text[pos+1] if pos < len(text)-1 else ""

    def commit():
        """Emit `current` as a finished chunk and reset accumulation state."""
        nonlocal rv, current, split_pos
        rv.append(current)
        current = ""
        split_pos = []

    while pos < len(text) - 1:
        c, next_c = seek(1)
        current += c
        # do we need to force a split?
        if len(current) >= max_length:
            if len(split_pos) > 0 and len(current) > (desired_length / 2):
                # we have at least one sentence and we are over half the
                # desired length, seek back to the last split
                d = pos - split_pos[-1]
                seek(-d)
                current = current[:-d]
            else:
                # no full sentences, seek back until we are not in the middle
                # of a word and split there
                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
                    c, _ = seek(-1)
                    current = current[:-1]
            commit()
        # check for sentence boundaries (note: next_c == "" at end of text
        # satisfies `in '\n '`, so a final '.' also counts as a boundary)
        elif not in_quote and (c in '!?\n' or (c == '.' and next_c in '\n ')):
            split_pos.append(pos)
            if len(current) >= desired_length:
                commit()
    rv.append(current)

    # clean up: strip each chunk and drop any that are empty
    rv = [s.strip() for s in rv]
    rv = [s for s in rv if len(s) > 0]

    return rv
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Self-test harness: run this module directly to exercise the splitter.
    import unittest

    class Test(unittest.TestCase):
        def test_split_and_recombine_text(self):
            # Multiline, indented input: normalization collapses the
            # newlines and leading whitespace before splitting.
            text = """
            This is a sample sentence.
            This is another sample sentence.
            This is a longer sample sentence that should force a split inthemiddlebutinotinthislongword.
            "Don't split my quote... please"
            """
            self.assertEqual(split_and_recombine_text(text, desired_length=20, max_length=40),
                             ['This is a sample sentence.',
                              'This is another sample sentence.',
                              'This is a longer sample sentence that',
                              'should force a split',
                              'inthemiddlebutinotinthislongword.',
                              '"Don\'t split my quote... please"'])

    unittest.main()
|
Loading…
Reference in New Issue
Block a user