support tts typing
commit 8fe0dff33c
parent 48cb6a5abd
@@ -48,6 +48,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.paths = [self.paths]
         self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
         self.total_size_bytes = sum(self.paths_size_bytes)
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])

         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
@@ -101,22 +102,23 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             else:
                 rand_offset -= self.paths_size_bytes[i]
         path = self.paths[i]
+        type = self.types[i]
         with open(path, 'r', encoding='utf-8') as f:
             f.seek(rand_offset)
             # Read the rest of the line we seeked to, then the line after that.
             try:  # This can fail when seeking to a UTF-8 escape byte.
                 f.readline()
             except:
-                return self.load_random_line(depth=depth + 1)  # On failure, just recurse and try again.
+                return self.load_random_line(depth=depth + 1), type  # On failure, just recurse and try again.
             l2 = f.readline()

         if l2:
             try:
                 base_path = os.path.dirname(path)
-                return parse_tsv_aligned_codes(l2, base_path)
+                return parse_tsv_aligned_codes(l2, base_path), type
             except:
                 print(f"error parsing random offset: {sys.exc_info()}")
-        return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
+        return self.load_random_line(depth=depth+1), type  # On failure, just recurse and try again.

     def get_ctc_metadata(self, codes):
         grouped = groupby(codes.tolist())
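The type label is threaded through load_random_line, which samples by picking a random byte offset weighted across the TSV shards, seeking there, discarding the partial line it lands in, and reading the next full line. A minimal standalone sketch of that technique, assuming a list of plain-text shard paths (the function name and setup are illustrative, not from the repository):

import os
import random

def sample_random_line(paths):
    # Weight shards by byte size so each line is picked with roughly equal probability.
    sizes = [os.path.getsize(p) for p in paths]
    offset = random.randint(0, sum(sizes) - 1)
    for path, size in zip(paths, sizes):
        if offset < size:
            break
        offset -= size
    with open(path, 'r', encoding='utf-8') as f:
        f.seek(offset)
        f.readline()               # discard the partial line the seek landed in
        return f.readline(), path  # may be '' at EOF; callers should retry

Seeking into the middle of a multi-byte UTF-8 character can make that first readline raise, which is why the committed code wraps it in try/except and recurses; with the new tuple return, every one of those retry paths has to carry type as well, hence the three edited return statements.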
@@ -155,7 +157,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         start = time.time()
         self.skipped_items += 1
-        apt = self.load_random_line()
+        apt, type = self.load_random_line()
         try:
             tseq, wav, text, path = self.get_wav_text_pair(apt)
             if text is None or len(text.strip()) == 0:
@@ -204,7 +206,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
-            'load_time': self.load_times.mean()
+            'load_time': self.load_times.mean(),
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
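Since 'type' enters the result dict as a plain Python int, the default PyTorch collate function stacks it into a LongTensor, one label per clip. A hedged sketch of reading it back out of a batch (only the 'type' key is taken from the dict above; the dataset construction is assumed):

from torch.utils.data import DataLoader

# `dataset` is assumed to be a FastPairedVoiceDataset built from hparams like those in __main__ below.
loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))
types = batch['type']  # LongTensor of shape (4,), suitable for UnifiedVoice.forward(..., types=types)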
@@ -266,10 +269,11 @@ if __name__ == '__main__':
         'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
                  'y:/libritts/train-clean-100/transcribed-oco.tsv',
                  'y:/libritts/train-clean-360/transcribed-oco.tsv',
-                 'y:/clips/books1/transcribed-w2v.tsv',
-                 'y:/clips/books2/transcribed-w2v.tsv',
-                 'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv',
+                 'y:/clips/books1/transcribed-oco.tsv',
+                 'y:/clips/books2/transcribed-oco.tsv',
+                 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
+                 'y:/clips/podcasts-1/transcribed-oco.tsv',],
+        'types': [0,1,1,1,2,2,0],
         'phase': 'train',
         'n_workers': 0,
         'batch_size': batch_sz,
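Note that 'types' pairs index-for-index with 'path': seven labels for seven TSVs, so the train-other-500 shard is type 0 while books2 and hifi_tts are type 2. When the key is omitted, the opt_get fallback from the first hunk ([0 for _ in self.paths]) assigns type 0 to every dataset, which keeps existing configs working unchanged.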
@@ -9,11 +9,28 @@ import torchaudio
 from tqdm import tqdm

 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
-from models.audio.tts.tacotron2 import load_filepaths_and_text
+from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
 from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
 from utils.util import opt_get


+def load_tsv_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 2:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
+    return filepaths_and_text
+
+
 def load_tsv(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
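A usage sketch for the new loader. The two-column layout (transcript first, then an audio path relative to the TSV) follows the indexing in the code above; the concrete file and contents are hypothetical:

# /data/set/transcribed-oco.tsv (hypothetical), containing one line:
#   hello world<TAB>clips/0001.wav
entries = load_tsv_type('/data/set/transcribed-oco.tsv', type=1)
print(entries[0])
# ['/data/set/clips/0001.wav', 'hello world', 1]

The trailing + [type] on the append is the only behavioral difference from the untyped load_tsv kept below it.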
@@ -40,6 +57,23 @@ def convert_string_list_to_tensor(strlist):
     return torch.tensor(as_ints)


+def load_tsv_aligned_codes_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 3:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
+    return filepaths_and_text
+
+
 def load_tsv_aligned_codes(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
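load_tsv_aligned_codes_type follows the same template but requires a third tab-separated field, which convert_string_list_to_tensor (its closing return torch.tensor(as_ints) is the context line at the top of this hunk) parses into an integer tensor of aligned codes. The on-disk string format of that column is not visible in this diff, so the safest statement is the entry shape it produces: [audio_path, transcript, codes_tensor, type].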
@@ -57,15 +91,15 @@ def load_tsv_aligned_codes(filename):
     return filepaths_and_text


-def load_mozilla_cv(filename):
+def load_mozilla_cv(filename, type):
     with open(filename, encoding='utf-8') as f:
         components = [line.strip().split('\t') for line in f][1:]  # First line is the header
         base = os.path.dirname(filename)
-        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
+        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
     return filepaths_and_text


-def load_voxpopuli(filename):
+def load_voxpopuli(filename, type):
     with open(filename, encoding='utf-8') as f:
         lines = [line.strip().split('\t') for line in f][1:]  # First line is the header
         base = os.path.dirname(filename)
@@ -75,7 +109,7 @@ def load_voxpopuli(filename):
             continue
         file, raw_text, norm_text, speaker_id, split, gender = line
         year = file[:4]
-        filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
+        filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
     return filepaths_and_text

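The pattern across all four fetchers is uniform: whatever list each loader used to emit per sample now carries the dataset's type as its final element, so a mozilla_cv or voxpopuli entry becomes [audio_path, text, type] and the typed TSV loaders append it after their existing fields.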
@@ -92,6 +126,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.path = hparams['path']
         if not isinstance(self.path, list):
             self.path = [self.path]
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.path])

         fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
         if not isinstance(fetcher_mode, list):
@@ -105,11 +140,11 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
         self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
         self.audiopaths_and_text = []
-        for p, fm in zip(self.path, fetcher_mode):
+        for p, fm, type in zip(self.path, fetcher_mode, self.types):
             if fm == 'lj' or fm == 'libritts':
-                fetcher_fn = load_filepaths_and_text
+                fetcher_fn = load_filepaths_and_text_type
             elif fm == 'tsv':
-                fetcher_fn = load_tsv_aligned_codes if self.load_aligned_codes else load_tsv
+                fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
             elif fm == 'mozilla_cv':
                 assert not self.load_conditioning  # Conditioning inputs are incompatible with mozilla_cv
                 fetcher_fn = load_mozilla_cv
@@ -118,7 +153,7 @@ class TextWavLoader(torch.utils.data.Dataset):
                 fetcher_fn = load_voxpopuli
             else:
                 raise NotImplementedError()
-            self.audiopaths_and_text.extend(fetcher_fn(p))
+            self.audiopaths_and_text.extend(fetcher_fn(p, type))
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
         random.seed(hparams.seed)
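A hedged sketch of how mixed corpora flow through this dispatch loop; the paths and labels are illustrative, and only keys visible in this diff are shown:

hparams = {
    'path':         ['lj/metadata.csv', 'vox/asr_train.tsv'],
    'fetcher_mode': ['lj',              'voxpopuli'],
    'types':        [0,                 1],
    # ... plus the loader's other usual keys
}
# iteration 1: audiopaths_and_text += load_filepaths_and_text_type('lj/metadata.csv', 0)
# iteration 2: audiopaths_and_text += load_voxpopuli('vox/asr_train.tsv', 1)

One wrinkle worth knowing: self.types defaults to a zero list sized to self.path, but if a config supplies fewer 'types' than paths, zip silently truncates the loop and the extra datasets are dropped rather than raising.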
@@ -138,10 +173,10 @@ class TextWavLoader(torch.utils.data.Dataset):

     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
-        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+        audiopath, text, type = audiopath_and_text[0], audiopath_and_text[1], audiopath_and_text[2]
         text_seq = self.get_text(text)
         wav = load_audio(audiopath, self.sample_rate)
-        return (text_seq, wav, text, audiopath_and_text[0])
+        return (text_seq, wav, text, audiopath_and_text[0], type)

     def get_text(self, text):
         tokens = self.tokenizer.encode(text)
@@ -156,7 +191,7 @@ class TextWavLoader(torch.utils.data.Dataset):
     def __getitem__(self, index):
         self.skipped_items += 1
         try:
-            tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
+            tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
             if text is None or len(text.strip()) == 0:
                 raise ValueError
             if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
@@ -202,6 +237,7 @@ class TextWavLoader(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -249,7 +285,7 @@ if __name__ == '__main__':
     batch_sz = 8
     params = {
         'mode': 'paired_voice_audio',
-        'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
+        'path': ['Y:\\clips\\books1\\transcribed-oco.tsv'],
         'fetcher_mode': ['tsv'],
         'phase': 'train',
         'n_workers': 0,
@@ -26,6 +26,14 @@ def load_wav_to_torch(full_path):
     return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)


+def load_filepaths_and_text_type(filename, type, split="|"):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
+        base = os.path.dirname(filename)
+        for j in range(len(filepaths_and_text)):
+            filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
+    return filepaths_and_text
+
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
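For LJ-style metadata the delimiter is a pipe. A hypothetical line and the entry it would produce (a relative wav path in the first column is an assumption consistent with the os.path.join above):

# /data/lj/metadata.csv (hypothetical), containing one line:
#   wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned
entries = load_filepaths_and_text_type('/data/lj/metadata.csv', type=0)
# entries[0] == ['/data/lj/wavs/LJ001-0001.wav',
#                'Printing, in the only sense with which we are at present concerned', 0]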
@@ -238,10 +238,10 @@ class MelEncoder(nn.Module):

 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
-                 mel_length_compression=1024, number_text_tokens=256,
-                 start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
+                 mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
+                 types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -252,7 +252,6 @@ class UnifiedVoice(nn.Module):
            max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
            mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
            number_text_tokens:
-           start_text_token:
            stop_text_token:
            number_mel_codes:
            start_mel_token:
@@ -265,8 +264,8 @@ class UnifiedVoice(nn.Module):
         super().__init__()

         self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
+        self.start_text_token = number_text_tokens * types
+        self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
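The start and stop ids are now derived instead of passed in: the stop/pad token is pinned to 0 and the start token sits one past the expanded text-id range. Worked through with the values this commit actually uses:

# defaults: number_text_tokens=256, types=1 -> start_text_token = 256 (the old explicit default was 255)
# __main__ test below: number_text_tokens=256, types=2
start_text_token = 256 * 2  # = 512, one past the largest remapped text id (255 * 2 = 510)
stop_text_token = 0         # shared by all types; the multiplicative remap in forward() keeps 0 fixed

Note the subtle default shift from 255 to 256: checkpoints trained against the old signature would need that accounted for before loading.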
@@ -279,7 +278,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
@@ -294,7 +293,7 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0

         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)

         # Initialize the embeddings per the GPT-2 scheme
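The embedding table and the text head grow in lockstep, so the model can both embed and predict every remapped id plus the start token:

vocab = number_text_tokens * types + 1  # e.g. 256 * 2 + 1 = 513 rows and logits
# indices 0..510 hold remapped text ids and 512 is the start token; under the
# default types=1 this works out to 257 rather than the previous 256.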
@@ -366,7 +365,7 @@ class UnifiedVoice(nn.Module):
         else:
             return first_logits

-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
                 return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@@ -383,6 +382,10 @@ class UnifiedVoice(nn.Module):
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
+        # Types are expressed by expanding the text embedding space.
+        if types is not None:
+            text_inputs = text_inputs * (1+types).unsqueeze(-1)
+
         if clip_inputs:
             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
             # chopping the inputs by the maximum actual length.
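The remap itself is multiplicative: token id t from a type-k sample becomes t * (k + 1). With the __main__ values below (raw ids under 256, types=2), type-0 ids stay in 0..255 while type-1 ids land on even indices up to 510, all inside the 513-entry table. Multiplying rather than adding an offset keeps id 0 (the stop/pad token) identical across types, at the cost of partial overlap between ranges (id t of a type-1 sample shares an embedding with id 2t of a type-0 sample). A minimal check:

import torch

text_inputs = torch.tensor([[5, 7, 0],   # type-0 sample; the trailing 0 is padding
                            [5, 7, 0]])  # type-1 sample with the same raw ids
types = torch.tensor([0, 1])
print(text_inputs * (1 + types).unsqueeze(-1))
# tensor([[ 5,  7,  0],
#         [10, 14,  0]])  -- padding stays 0; the largest possible id is 255*2 = 510 < 513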
@@ -579,10 +582,11 @@ def register_unified_voice2(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
-            torch.tensor([250*256,195*256]))
+            torch.tensor([250*256,195*256]),
+            types=torch.tensor([0, 1]))
     #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))