1493 lines
49 KiB
Python
Executable File
1493 lines
49 KiB
Python
Executable File
# todo: clean this mess up
|
|
|
|
import copy
|
|
import h5py
|
|
import json
|
|
import logging
|
|
import numpy as np
|
|
import os
|
|
import random
|
|
import torch
|
|
import itertools
|
|
|
|
from .config import cfg
|
|
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
|
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
|
from .utils.distributed import global_rank, local_rank, world_size
|
|
|
|
from collections import defaultdict
|
|
from functools import cache, cached_property
|
|
from itertools import groupby, zip_longest
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from torch import Tensor
|
|
from torch.utils.data import DataLoader, Dataset as _Dataset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
from tqdm.auto import tqdm
|
|
# torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
# module-level logger; e.g. prompt sampling warns through it about HDF5 entries that lack audio
_logger = logging.getLogger(__name__)
|
|
|
|
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs(
    text_list=None,
    prom_list=None,
    resp_list=None,
    targ_list=None,
    ignore_index=None,
    sep=3,
    stop=3,
    text_tokens=256,
    audio_tokens=1024,
    audio_rvq_levels=cfg.model.max_levels,
    quant_levels=None,
):
    """Fold per-modality token lists into one unified token-ID sequence per batch item.

    Unified ID layout (from the offsets used below):
      [0, text_tokens)                                  text tokens
      [text_tokens, text_tokens + audio_rvq_levels)     RVQ-level marker tokens
      next audio_tokens * audio_rvq_levels IDs          prompt audio codes
      IDs after that                                    response / target audio codes
    Within an audio range, a code c at RVQ level k maps to base + audio_tokens * k + c.

    Args:
        text_list: per-item 1D text token tensors; fixes the batch size and output device.
        prom_list: per-item 2D prompt codes whose second axis is indexed by RVQ level.
        resp_list: per-item response codes: a 2D tensor like `prom_list`, or (deinterleaved
            case only) a list indexed by RVQ level.
        targ_list: per-item 2D target codes, laid out like `prom_list`.
        ignore_index: if given, every prompt token is replaced by this value (not offset).
        sep / stop: separator / stop token IDs.
        text_tokens / audio_tokens / audio_rvq_levels: sizes of the token spaces.
        quant_levels: per-item RVQ level. When given, only that level is folded
            ("deinterleaved"); otherwise all levels are flattened ("interleaved"), which
            assumes the level axis has exactly `audio_rvq_levels` entries.

    Returns:
        (input_ids, mask): int64 IDs right-padded with 0 into shape (batch, time), and a
        mask of the same shape/dtype with 1 for real tokens and 0 for padding.

    Input tensors are never modified.
    """
    # `None` defaults replace shared mutable `[]` defaults
    text_list = [] if text_list is None else text_list
    prom_list = [] if prom_list is None else prom_list
    resp_list = [] if resp_list is None else resp_list
    targ_list = [] if targ_list is None else targ_list

    def _create_mask(lengths, device):
        seq = torch.arange(max(lengths), device=device).unsqueeze(0)  # (1 t)
        stop = torch.tensor(lengths, device=device).unsqueeze(1)  # (b 1)
        return (seq < stop).float()  # (b t)

    def list_to_tensor(x_list: list[Tensor]):
        lengths = list(map(len, x_list))
        x = pad_sequence(x_list).t()
        m = _create_mask(lengths, x_list[0].device).to(x)
        return x, m

    def _ids(codes):
        # `.to()` may hand back the caller's tensor (or a view of it) unchanged, so all
        # offsets below are applied out-of-place
        return codes.to("cpu", dtype=torch.int64)

    def _level_ids(codes, base, level):
        # every code belongs to the same RVQ level
        return _ids(codes) + (base + audio_tokens * level)

    def _interleaved_ids(codes, base):
        # flattened row-major: position idx belongs to RVQ level idx % audio_rvq_levels
        ids = _ids(codes.flatten())
        levels = torch.arange(ids.shape[0]) % audio_rvq_levels
        return ids + base + audio_tokens * levels

    device = text_list[0].device
    batch_size = len(text_list)
    input_ids = [[] for _ in range(batch_size)]

    sep = torch.tensor([sep], dtype=torch.int64)
    stop = torch.tensor([stop], dtype=torch.int64)

    level_offset = text_tokens
    prom_offset = text_tokens + audio_rvq_levels
    resp_offset = prom_offset + audio_tokens * audio_rvq_levels

    for i, text in enumerate(text_list):
        input_ids[i].append(_ids(text))
        input_ids[i].append(sep)

    # inject target quant_level
    if quant_levels is not None:
        for i, rvq in enumerate(quant_levels):
            input_ids[i].append(torch.tensor([level_offset + int(rvq)], dtype=torch.int64))
            input_ids[i].append(sep)

    for i, prom in enumerate(prom_list):
        if quant_levels is not None:
            # deinterleaved: only the requested level
            quant_level = int(quant_levels[i])
            if ignore_index is not None:
                seq = torch.full((prom.shape[0],), ignore_index, dtype=torch.int64)
            else:
                seq = _level_ids(prom[:, quant_level], prom_offset, quant_level)
        else:
            # interleaved: all levels
            if ignore_index is not None:
                seq = torch.full((prom.shape[0] * prom.shape[1],), ignore_index, dtype=torch.int64)
            else:
                seq = _interleaved_ids(prom, prom_offset)

        input_ids[i].append(seq)
        input_ids[i].append(sep)

    for i, resp in enumerate(resp_list):
        if quant_levels is not None:
            # deinterleaved: the response carries the previous RVQ level
            quant_level = int(quant_levels[i]) - 1
            # way to signal we want to inference for rvq level 0
            # without it, it's a random chance for any level to be selected again
            if quant_level < 0:
                continue
            # per-level lists of tensors can't be indexed with a tuple
            codes = resp[quant_level] if isinstance(resp, list) else resp[:, quant_level]
            seq = _level_ids(codes, resp_offset, quant_level)
        else:
            seq = _interleaved_ids(resp, resp_offset)

        input_ids[i].append(seq)
        input_ids[i].append(stop)

    for i, targ in enumerate(targ_list):
        if quant_levels is not None:
            quant_level = int(quant_levels[i])
            seq = _level_ids(targ[:, quant_level], resp_offset, quant_level)
        else:
            seq = _interleaved_ids(targ, resp_offset)

        input_ids[i].append(seq)
        input_ids[i].append(stop)

    for i, parts in enumerate(input_ids):
        input_ids[i] = torch.cat(parts, dim=-1).to(device=device, dtype=torch.int64)

    return list_to_tensor(input_ids)
|
|
|
|
# unfold from one unified token ID space to separate token spaces
# to-do: unfold at a specific RVQ level instead if requested
def unfold_outputs(
    output_ids,
    sep=3,
    stop=3,
    text_tokens=256,
    audio_tokens=1024,
    audio_rvq_levels=cfg.model.max_levels,
    quant_levels=None,
):
    """Split unified token IDs back into text, prompt and response token tensors.

    IDs are classified by range: [0, text_tokens) is text; from
    text_tokens + audio_rvq_levels up to the end of the prompt range is prompt audio;
    anything above is response audio. `sep`/`stop` IDs and RVQ-level marker IDs are
    dropped. Audio codes are recovered modulo `audio_tokens`.

    Args:
        output_ids: 2D tensor, one row of unified IDs per batch item.
        quant_levels: per-item RVQ level. When given, each item yields 1D prompt and
            response tensors. For levels > 0, the first run of response tokens (the
            previous-level prefix) is discarded at the next sep/stop. When None, tokens are
            treated as interleaved across `audio_rvq_levels` and regrouped into
            (frames, levels) tensors; a trailing partial frame is dropped.

    Returns:
        dict with `text_list`, `prom_list`, `resp_list`: per-item int64 tensors on the
        device of `output_ids`.
    """
    device = output_ids.device
    batch_size = output_ids.shape[0]

    prom_offset = text_tokens + audio_rvq_levels
    resp_offset = prom_offset + audio_tokens * audio_rvq_levels

    text_list = [[] for _ in range(batch_size)]
    prom_list = [[] for _ in range(batch_size)]
    resp_list = [[] for _ in range(batch_size)]

    def _frames(codes):
        # interleaved codes cycle through the levels, so position p is level p % levels;
        # keep only whole frames so every level has the same length, then (frames, levels)
        usable = (len(codes) // audio_rvq_levels) * audio_rvq_levels
        return torch.tensor(codes[:usable], dtype=torch.int64, device=device).reshape(-1, audio_rvq_levels)

    for i, batch in enumerate(output_ids):
        # cringe logic to handle the prefix resp for rvq levels > 0
        # a better way is to observe if the rvq level increased
        should_flush = False
        flushed = False
        for token_id in batch.tolist():
            if token_id == sep or token_id == stop:
                if should_flush and quant_levels is not None and quant_levels[i] > 0:
                    resp_list[i] = []
                    should_flush = False
                    flushed = True
                continue

            if 0 <= token_id < text_tokens:
                text_list[i].append(token_id)
            elif prom_offset <= token_id < resp_offset:
                prom_list[i].append((token_id - prom_offset) % audio_tokens)
            elif resp_offset <= token_id:
                resp_list[i].append((token_id - prom_offset) % audio_tokens)
                if not flushed:
                    should_flush = True
            # IDs in [text_tokens, prom_offset) are RVQ-level markers and are skipped

        if quant_levels is not None:
            prom_list[i] = torch.tensor(prom_list[i], dtype=torch.int64, device=device)
            resp_list[i] = torch.tensor(resp_list[i], dtype=torch.int64, device=device)
        else:
            prom_list[i] = _frames(prom_list[i])
            resp_list[i] = _frames(resp_list[i])

        text_list[i] = torch.tensor(text_list[i], dtype=torch.int64, device=device)

    return dict(
        text_list=text_list,
        prom_list=prom_list,
        resp_list=resp_list,
    )
|
|
|
|
# to-do: clean up this symmap mess
def get_phone_symmap():
    """Return the configured tokenizer's vocabulary (presumably token string -> id)."""
    tokenizer = cfg.tokenizer
    return tokenizer.get_vocab()
|
|
|
|
def tokenize( phones ):
    """Encode `phones` (a string or an iterable of strings, concatenated first) with the configured tokenizer."""
    encode = cfg.tokenizer.encode
    return encode("".join(phones))
|
|
|
|
def get_lang_symmap():
    """Return a fresh language-code -> id mapping (ids follow declaration order)."""
    languages = ("en", "ja")
    return {code: index for index, code in enumerate(languages)}
|
|
|
|
def get_tone_symmap():
    """Return the tone -> id mapping (currently only "neutral")."""
    # the former trailing `return symmap` was unreachable and referenced an undefined name
    return {
        "neutral": 0,
    }
|
|
|
|
def get_task_symmap():
    """Return the task-token -> id mapping; ids follow the declaration order below."""
    task_tokens = [
        "<tts>",
        "<tts-c>",
        "<ns>",
        "<sr>",
        "<tse>",
        "<soe>",
        "<mask>",
        "<eoe>",
    ]
    return {token: index for index, token in enumerate(task_tokens)}
|
|
|
|
def _replace_file_extension(path, suffix):
|
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
|
|
|
def _get_quant_extension():
    """Return the quantized-audio file extension for the configured audio backend."""
    if cfg.audio_backend == "dac":
        return ".dac"
    return ".enc"
|
|
|
|
def _get_phone_extension():
|
|
return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"
|
|
|
|
def _get_quant_path(path):
    """Return the quantized-audio file that sits next to `path` (same stem, quant extension)."""
    quant_suffix = _get_quant_extension()
    return _replace_file_extension(path, quant_suffix)
|
|
|
|
def _get_phone_path(path):
    """Return the phoneme/metadata file that sits next to `path` (same stem, phone extension)."""
    phone_suffix = _get_phone_extension()
    return _replace_file_extension(path, phone_suffix)
|
|
|
|
# durations keyed by dataset split, then by utterance path/key; filled in as a side
# effect of the path loaders' `_validate` helpers
_durations_map = {}

# makeshift caching the above to disk
@cfg.diskcache()
def _get_duration_map( type="training" ):
    """Return the durations recorded for split `type`, or an empty dict if none were recorded."""
    return _durations_map.get(type, {})
|
|
|
|
@cfg.diskcache()
def _load_paths(dataset, type="training"):
    """Map speaker name -> utterance paths for every data dir in `dataset`.

    The speaker name is what `cfg.get_spkr` derives from a dummy file inside each data
    dir. Validation (duration / phoneme-count bounds) is only requested for the
    "training" split, and only when `cfg.dataset.validate` is set.
    """
    paths_by_spkr = {}
    for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}"):
        spkr = cfg.get_spkr( cfg.data_dir / data_dir / "dummy" )
        paths_by_spkr[spkr] = _load_paths_from_metadata(
            data_dir,
            type=type,
            validate=cfg.dataset.validate and type == "training",
        )
    return paths_by_spkr
|
|
|
|
def _load_paths_from_metadata(group_name, type="training", validate=False):
    """Return the utterance paths/keys for one data group.

    If `cfg.dataset.use_metadata` is set and `cfg.metadata_dir / f"{group_name}.json"`
    exists and is non-empty, keys come from that JSON. Each entry's `duration`
    (0 if absent) is recorded in `_durations_map[type]`. Otherwise discovery is
    delegated to the HDF5 file or the filesystem.

    Args:
        group_name: data group (directory name / HDF5 group).
        type: dataset split used in HDF5 keys and as the `_durations_map` bucket.
        validate: keep only entries whose duration lies within
            [cfg.dataset.min_duration, cfg.dataset.max_duration].

    Returns:
        list of HDF5 key strings (HDF5 mode) or `Path`s (filesystem mode).
    """
    data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name

    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions

    def key( utt_id, entry=None ):
        return f"/{type}/{_get_hdf5_path(data_dir)}/{utt_id}" if cfg.dataset.use_hdf5 else data_dir / utt_id

    metadata_path = cfg.metadata_dir / f'{group_name}.json'
    metadata = {}

    if cfg.dataset.use_metadata and metadata_path.exists():
        # context manager so the file handle is closed deterministically
        with open( metadata_path, "r", encoding="utf-8" ) as f:
            metadata = json.load(f)

    if len(metadata) == 0:
        return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )

    def _validate( utt_id, entry ):
        duration = entry.get("duration", 0)

        # add to duration bucket
        _durations_map.setdefault(type, {})[key(utt_id, entry)] = duration

        if not validate:
            return True

        return cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration

    return [ key(utt_id, entry) for utt_id, entry in metadata.items() if _validate(utt_id, entry) ]
|
|
|
|
|
|
def _get_hdf5_path(path):
|
|
# to-do: better validation
|
|
#print(path)
|
|
return str(path)
|
|
|
|
def _get_hdf5_paths( data_dir, type="training", validate=False ):
    """Return the utterance paths stored under the HDF5 group `/{type}/{data_dir}`.

    Side effect: each visited utterance's `duration` attribute is recorded in
    `_durations_map[type]`. The record is keyed by the exact `Path` object that is
    returned, because a `Path` never equals its string form as a dict key, and callers
    look durations up by the returned paths.

    Args:
        data_dir: group name under `/{type}/`.
        type: dataset split (e.g. "training", "validation", "noise").
        validate: keep only utterances whose duration lies within
            [cfg.dataset.min_duration, cfg.dataset.max_duration].

    Returns:
        list of `Path` keys; empty if the group does not exist in `cfg.hdf5`.
    """
    data_dir = str(data_dir)

    key = f"/{type}/{_get_hdf5_path(data_dir)}"

    if key not in cfg.hdf5:
        return []

    def _validate( path, entry ):
        duration = entry.attrs['duration']

        _durations_map.setdefault(type, {})[path] = duration

        if not validate:
            return True

        return cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration

    paths = []
    for utt_id, entry in cfg.hdf5[key].items():
        path = Path(f"{key}/{utt_id}")
        if _validate(path, entry):
            paths.append(path)
    return paths
|
|
|
|
def _get_paths_of_extensions( path, extensions=None, validate=False ):
    """List the utterance files directly inside directory `path`.

    A file is kept when its combined suffixes (e.g. ".dac") exactly match one of
    `extensions`, and both its phoneme file and quantized-audio file exist. With
    `validate`, its phoneme count must also lie within
    [cfg.dataset.min_phones, cfg.dataset.max_phones]. No durations are recorded here.

    Args:
        path: directory (str or Path).
        extensions: one extension string or a collection of them; defaults to the
            configured quantized-audio extension, resolved at call time.
        validate: apply the phoneme-count bounds.

    Returns:
        list of `Path`s; empty if `path` is not an existing directory.
    """
    if extensions is None:
        extensions = _get_quant_extension()
    # a bare string would make `in` a substring test ("" or ".d" would match ".dac")
    if isinstance(extensions, str):
        extensions = (extensions,)

    if isinstance(path, str):
        path = Path(path)

    def _validate(candidate):
        if "".join(candidate.suffixes) not in extensions:
            return False
        if not _get_phone_path(candidate).exists() or not _get_quant_path(candidate).exists():
            return False
        if not validate:
            return True
        # to-do: find an easy way to determine size from pickled quants without loading
        # to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
        phones = len(_get_phones(_get_phone_path(candidate)))
        return cfg.dataset.min_phones <= phones <= cfg.dataset.max_phones

    if not (path.exists() and path.is_dir()):
        return []

    return [ p for p in path.iterdir() if _validate(p) ]
|
|
|
|
def _load_quants(path, return_metadata=False) -> Tensor:
    """Load the codes stored in the quantized-audio file belonging to `path`.

    The file holds a pickled dict with "codes" and "metadata"; the first entry of
    "codes" is transposed and returned as an int16 tensor.

    NOTE(review): `allow_pickle=True` executes pickle data, so only load trusted files.

    Returns:
        the codes tensor, or `(codes, metadata)` when `return_metadata` is true.
    """
    payload = np.load(_get_quant_path(path), allow_pickle=True)[()]
    codes = torch.from_numpy(payload["codes"].astype(int))[0][:, :].t().to(torch.int16)
    if not return_metadata:
        return codes
    return codes, payload["metadata"]
|
|
|
|
# prune consecutive spaces
|
|
def _cleanup_phones( phones, targets=[" "]):
|
|
return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
|
|
|
|
@cache
def _get_phones(path):
    """Return the phoneme string for the utterance at `path` (memoized per path, unbounded).

    Reads "phonemes" from the JSON phoneme file when it exists, otherwise from the
    metadata embedded in the quantized-audio file, and joins it into one string.

    Raises:
        Exception: if neither file exists.
    """
    phone_path = _get_phone_path(path)
    quant_path = _get_quant_path(path)

    if phone_path.exists():
        # context manager so the file handle is closed deterministically
        with open(phone_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
    elif quant_path.exists():
        _, metadata = _load_quants( path, return_metadata=True )
    else:
        raise Exception(f"Could not load phonemes: {path}")

    return "".join(metadata["phonemes"])
|
|
|
|
def _interleaved_reorder(l, fn):
|
|
groups = defaultdict(list)
|
|
for e in l:
|
|
groups[fn(e)].append(e)
|
|
groups = {k: groups[k] for k in sorted(groups)}
|
|
for interleaved in zip_longest(*groups.values()):
|
|
for value in interleaved:
|
|
if value is not None:
|
|
yield value
|
|
|
|
class Dataset(_Dataset):
|
|
def __init__(
|
|
self,
|
|
phone_symmap=None,
|
|
training=False,
|
|
extra_paths_by_spkr_name: dict[str, list] = {},
|
|
):
|
|
super().__init__()
|
|
self._head = None
|
|
self.shuffle = False
|
|
self.sampler = None
|
|
|
|
self.paths = []
|
|
|
|
self.training = training
|
|
self.dataset_type = "training" if self.training else "validation"
|
|
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
|
self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
|
|
self.sampler_order = cfg.dataset.sample_order
|
|
|
|
# to-do: do not do validation if there's nothing in the validation
|
|
# this just makes it be happy
|
|
if len(self.dataset) == 0:
|
|
self.dataset = cfg.dataset.training
|
|
|
|
# dict of paths keyed by speaker names
|
|
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
|
|
|
# cull speakers if they do not have enough utterances
|
|
if cfg.dataset.min_utterances > 0:
|
|
keys = list(self.paths_by_spkr_name.keys())
|
|
for key in keys:
|
|
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
|
|
del self.paths_by_spkr_name[key]
|
|
|
|
# flatten paths
|
|
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
|
|
|
# split dataset accordingly per GPU
|
|
if cfg.distributed and self.training:
|
|
"""
|
|
batches = len(self.paths) // world_size()
|
|
start = batches * global_rank()
|
|
end = batches * (global_rank() + 1)
|
|
|
|
self.paths = self.paths[start:end]
|
|
"""
|
|
|
|
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
|
|
|
|
# recreate paths_by_spkr_name
|
|
self.paths_by_spkr_name = {}
|
|
for path in self.paths:
|
|
name = cfg.get_spkr( Path(path) )
|
|
if name not in self.paths_by_spkr_name:
|
|
self.paths_by_spkr_name[name] = []
|
|
self.paths_by_spkr_name[name].append( path )
|
|
|
|
# do it here due to the above
|
|
self.duration = 0
|
|
self.duration_map = _get_duration_map( self.dataset_type )
|
|
self.duration_buckets = {}
|
|
|
|
# store in corresponding bucket
|
|
for path in self.paths:
|
|
duration = self.duration_map[path]
|
|
self.duration += duration
|
|
|
|
# only calc duration if we're tot going to order by duration
|
|
if self.sampler_order != "duration":
|
|
continue
|
|
|
|
bucket = int(round(duration))
|
|
if bucket not in self.duration_buckets:
|
|
self.duration_buckets[bucket] = []
|
|
self.duration_buckets[bucket].append( ( Path(path), duration ) )
|
|
|
|
# ensure they're ordered
|
|
self.duration_buckets = dict(sorted(self.duration_buckets.items()))
|
|
|
|
# sort by duration
|
|
if self.sampler_order == "duration":
|
|
flattened = {}
|
|
# sort and interleave
|
|
for bucket in self.duration_buckets:
|
|
# sort by duration
|
|
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
|
|
# split to retain tuples
|
|
flattened[bucket] = self.duration_buckets[bucket]
|
|
# replace with path
|
|
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
|
|
# flatten by paths
|
|
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
|
|
# flatten paths
|
|
self.paths = list(itertools.chain.from_iterable(flattened.values()))
|
|
elif self.sampler_order == "shuffle":
|
|
# just interleave
|
|
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
|
|
|
|
|
# dict of speakers keyed by speaker group
|
|
self.spkrs_by_spkr_group = {}
|
|
for data_dir in self.dataset:
|
|
spkr = cfg.get_spkr( data_dir / "dummy" )
|
|
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
|
|
|
|
if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
|
|
continue
|
|
|
|
if spkr_group not in self.spkrs_by_spkr_group:
|
|
self.spkrs_by_spkr_group[spkr_group] = []
|
|
|
|
self.spkrs_by_spkr_group[spkr_group].append( spkr )
|
|
|
|
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
|
|
|
|
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
|
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
|
|
|
|
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
|
self.spkr_symmap = self._get_spkr_symmap()
|
|
self.spkr_group_symmap = self._get_spkr_group_symmap()
|
|
self.lang_symmap = self._get_lang_symmap()
|
|
self.tone_symmap = self._get_tone_symmap()
|
|
self.task_symmap = self._get_task_symmap()
|
|
|
|
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
|
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
|
|
|
if len(self.paths) == 0:
|
|
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
|
|
|
sampler_path = cfg.rel_path / self.sampler_state_dict_path
|
|
|
|
if self.sampler_type == "path":
|
|
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
|
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size )
|
|
else:
|
|
self.sampler = OrderedSampler( len(self) )
|
|
self.samplers = {}
|
|
self.spkr_samplers = {}
|
|
else:
|
|
self.sampler = RandomSampler( len(self) )
|
|
self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
|
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
|
|
|
self.load_state_dict()
|
|
|
|
@cached_property
|
|
def sampler_state_dict_path(self):
|
|
return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
|
|
|
def get_speaker(self, path):
|
|
if isinstance(path, str):
|
|
path = Path(path)
|
|
res = cfg.get_spkr(path)
|
|
return res
|
|
|
|
def get_speaker_group(self, path):
|
|
if isinstance(path, str):
|
|
path = Path(path)
|
|
res = cfg.get_spkr_group(path)
|
|
return res
|
|
|
|
def get_language(self, speaker_group):
|
|
lang = "en"
|
|
for k, v in cfg.dataset.speaker_languages.items():
|
|
if speaker_group in v:
|
|
lang = k
|
|
break
|
|
|
|
return lang
|
|
|
|
@cached_property
|
|
def spkrs(self):
|
|
return sorted({self.get_speaker(path) for path in self.paths})
|
|
|
|
@cached_property
|
|
def tasks(self):
|
|
return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
|
|
|
|
def save_state_dict(self, path = None):
|
|
if path is None:
|
|
path = cfg.rel_path / self.sampler_state_dict_path
|
|
|
|
if self.sampler_type == "path":
|
|
state_dict = self.sampler.get_state()
|
|
else:
|
|
state_dict = {
|
|
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
|
|
"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
|
|
}
|
|
torch.save(state_dict, path)
|
|
|
|
def load_state_dict(self, path = None):
|
|
if path is None:
|
|
path = cfg.rel_path / self.sampler_state_dict_path
|
|
|
|
if not path.exists():
|
|
return
|
|
|
|
state_dict = torch.load(path)
|
|
if self.sampler_type == "path":
|
|
state_dict = self.sampler.set_state(state_dict)
|
|
else:
|
|
for name, sampler in state_dict["samplers"].items():
|
|
if name not in self.samplers:
|
|
continue
|
|
self.samplers[name].set_state( sampler )
|
|
|
|
for name, sampler in state_dict["spkr_samplers"].items():
|
|
if name not in self.spkr_samplers:
|
|
continue
|
|
self.spkr_samplers[name].set_state( sampler )
|
|
|
|
def _get_phone_symmap(self):
|
|
return get_phone_symmap()
|
|
|
|
def _get_spkr_symmap(self):
|
|
return {s: i for i, s in enumerate(self.spkrs)}
|
|
|
|
def _get_spkr_group_symmap(self):
|
|
return {s: i for i, s in enumerate(self.spkr_groups)}
|
|
|
|
def _get_lang_symmap(self):
|
|
return get_lang_symmap()
|
|
|
|
def _get_tone_symmap(self):
|
|
return get_tone_symmap()
|
|
|
|
def _get_task_symmap(self):
|
|
return get_task_symmap()
|
|
|
|
"""
|
|
def get_task_token( self, token, levels=cfg.model.max_levels ):
|
|
if not hasattr(self, "task_symmap"):
|
|
self.task_symmap = self._get_task_symmap()
|
|
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
|
|
"""
|
|
|
|
def sample_noise(self):
|
|
path = random.choice(self.noise_paths)
|
|
|
|
if cfg.dataset.use_hdf5:
|
|
key = _get_hdf5_path(path)
|
|
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
|
else:
|
|
qnt = _load_quants(path, return_metadata=False)
|
|
return qnt
|
|
|
|
def sample_speakers(self, ignore=[]):
|
|
choices = set(self.spkrs) - set(ignore)
|
|
return random.choice([*choices])
|
|
|
|
def sample_prompts(self, spkr_name, ignore):
|
|
prom_list = []
|
|
|
|
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
|
|
choices = [*choices]
|
|
|
|
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
|
|
if len(choices) == 0:
|
|
choices = [*set(self.paths_by_spkr_name[spkr_name])]
|
|
"""
|
|
raise ValueError(
|
|
f"Failed to find another different utterance for {spkr_name}."
|
|
)
|
|
"""
|
|
|
|
prom_length = 0
|
|
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
|
|
|
for _ in range(cfg.dataset.max_prompts):
|
|
path = random.choice(choices)
|
|
if cfg.dataset.use_hdf5:
|
|
key = _get_hdf5_path(path)
|
|
|
|
if "audio" not in cfg.hdf5[key]:
|
|
_logger.warning(f'MISSING AUDIO: {key}')
|
|
continue
|
|
|
|
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
|
else:
|
|
qnt = _load_quants(path, return_metadata=False)
|
|
|
|
if 0 < trim_length and trim_length < qnt.shape[0]:
|
|
qnt = trim( qnt, trim_length )
|
|
|
|
prom_list.append(qnt)
|
|
prom_length += qnt.shape[0]
|
|
|
|
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
|
|
break
|
|
|
|
# might be better to decode => concat waveforms with silence in between => reencode
|
|
# as you technically can't just append encodec sequences together like this without issues
|
|
prom = torch.cat(prom_list)
|
|
|
|
if 0 < trim_length and trim_length < prom.shape[0]:
|
|
prom = trim( prom, trim_length )
|
|
|
|
return prom
|
|
|
|
def __getitem__(self, index):
|
|
if self.sampler_type == "group":
|
|
spkr_group = self.spkr_groups[index]
|
|
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
|
spkr_name = self.spkr_samplers[spkr_group].sample()
|
|
spkr_id = self.spkr_symmap[spkr_name]
|
|
path = self.samplers[spkr_name].sample()
|
|
elif self.sampler_type == "speaker":
|
|
spkr_name = self.spkrs[index]
|
|
spkr_id = self.spkr_symmap[spkr_name]
|
|
path = self.samplers[spkr_name].sample()
|
|
spkr_group = self.get_speaker_group(path)
|
|
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
|
else:
|
|
path = self.paths[index]
|
|
spkr_name = self.get_speaker(path)
|
|
spkr_id = self.spkr_symmap[spkr_name]
|
|
spkr_group = self.get_speaker_group(path)
|
|
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
|
|
|
if cfg.dataset.use_hdf5:
|
|
key = _get_hdf5_path(path)
|
|
|
|
if key not in cfg.hdf5:
|
|
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
|
|
|
text = cfg.hdf5[key]["text"][:]
|
|
resps = cfg.hdf5[key]["audio"][:, :]
|
|
|
|
text = torch.from_numpy(text).to(self.text_dtype)
|
|
resps = torch.from_numpy(resps).to(torch.int16)
|
|
else:
|
|
resps, metadata = _load_quants(path, return_metadata=True)
|
|
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
|
#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
|
|
|
|
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
|
|
|
|
# append additional prompts in an attempt to artifically increase lengths / offer new data
|
|
if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
|
|
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
|
|
|
|
if len(choices) > 0:
|
|
for _ in range( cfg.dataset.max_resps - 1 ):
|
|
sampled_path = random.choice(choices)
|
|
choices = [*(set(choices) - {sampled_path})]
|
|
if cfg.dataset.use_hdf5:
|
|
key = _get_hdf5_path(sampled_path)
|
|
txt = cfg.hdf5[key]["text"][:]
|
|
qnt = cfg.hdf5[key]["audio"][:, :]
|
|
|
|
txt = np.array( txt )
|
|
|
|
txt = torch.from_numpy(txt).to(self.text_dtype)
|
|
qnt = torch.from_numpy(qnt).to(torch.int16)
|
|
else:
|
|
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
|
|
#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
|
|
qnt, metadata = _load_quants(sampled_path, return_metadata=True)
|
|
txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
|
|
|
# <s>[original text] [new text]</s>
|
|
# removes the original text's </s>, includes a space, and remove the new text's <s>
|
|
text = torch.concat([ text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:] ])
|
|
|
|
# might be better to decode => concat waveforms with silence in between => reencode
|
|
# as you technically can't just append encodec sequences together like this without issues
|
|
resps = torch.concat([ resps, qnt ])
|
|
|
|
task = "tts"
|
|
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
|
|
|
|
|
# Disabled until I swap over to a better method
|
|
"""
|
|
task = random.choice(self.tasks)
|
|
|
|
# ensure a speaker has at least four utterances
|
|
# default to tts if not
|
|
if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
|
|
task = "tts"
|
|
noise_scale = 0.25
|
|
if task == "tts" or task == "tts-c":
|
|
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
|
|
# demote if the target is too short
|
|
if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
|
|
task = "tts"
|
|
|
|
# VALL-E continuous
|
|
# ignore if target utterance is shorter than prompt duration
|
|
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
|
|
if task == "tts-c":
|
|
proms = resps[:trim_length, :]
|
|
resps = resps[trim_length:, :]
|
|
|
|
proms = torch.cat( [self.get_task_token(task), proms] )
|
|
else:
|
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
|
# noise suppression || speech removal
|
|
elif task == "ns" or task == "sr":
|
|
# sample random noise
|
|
noise = self.sample_noise()
|
|
# extend the noise to fill the target audio
|
|
noise = repeat_extend_audio(noise, resps.shape[0])
|
|
# create the input prompt by merging the target audio with the noise
|
|
proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
|
|
# set the target to just be the noise if <sr>
|
|
if task == "sr":
|
|
resps = noise
|
|
# prepend the task token
|
|
proms = torch.cat( [self.get_task_token(task), proms] )
|
|
|
|
# set the text prompt to empty to train without a guided text prompt
|
|
if random.random() < 0.5:
|
|
text = torch.tensor([1, 2]).to(self.text_dtype)
|
|
# target speech extraction
|
|
elif task == "tse":
|
|
# sample a random, clean, utterance for the target speaker
|
|
clean_proms = self.sample_prompts(spkr_name, ignore=path)
|
|
# sample a random, clean utterance from a different speaker
|
|
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
|
|
# overlay the random speaker over the target audio
|
|
|
|
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
|
if other_proms.shape[0] == smallest_size:
|
|
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
|
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
|
else:
|
|
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
|
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
|
|
|
# stitch together the promps
|
|
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
|
|
|
|
# set the text prompt to empty to train without a guided text prompt
|
|
if random.random() < 0.5:
|
|
text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>
|
|
|
|
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
|
# as I need to get a good clean point to trim into
|
|
# clean speech editing
|
|
elif task == "cse" or task == "nse":
|
|
choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
|
|
sampled = random.sample([*choices], 4)
|
|
|
|
if cfg.dataset.use_hdf5:
|
|
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
|
|
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled ]
|
|
else:
|
|
texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
|
|
qnts = [ _load_quants(path) for path in sampled ]
|
|
|
|
# remove <s></s>
|
|
for i in range(len(texts)):
|
|
texts[i] = texts[i][1:-1]
|
|
|
|
pre_text, mid_text, post_text, edit_text = texts
|
|
pre_prom, mid_prom, post_prom, edit_prom = qnts
|
|
|
|
# randomly drop out pre
|
|
if random.random() < 0.125:
|
|
pre_text = None
|
|
pre_prom = None
|
|
# randomly drop out post
|
|
if random.random() < 0.125:
|
|
post_text = None
|
|
post_prom = None
|
|
|
|
# create new text
|
|
text = torch.cat(
|
|
[ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
|
|
([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
|
|
[ edit_text ] + # 'edit text'
|
|
([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
|
|
[ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
|
|
)
|
|
|
|
if task == "nse":
|
|
# sample random noise
|
|
noise = self.sample_noise()
|
|
|
|
# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
|
|
# but it's noise, it's supposed to be random
|
|
def noise_proms( p ):
|
|
# ignore if we turned it off
|
|
if p is None:
|
|
return None
|
|
|
|
# extend the noise to fill the target audio
|
|
n = repeat_extend_audio(noise, p.shape[0])
|
|
# merge the noise over the utterance
|
|
return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
|
|
|
|
# apply noise to all pieces
|
|
pre_prom = noise_proms( pre_prom )
|
|
mid_prom = noise_proms( mid_prom )
|
|
post_prom = noise_proms( post_prom )
|
|
edit_prom = noise_proms( edit_prom )
|
|
else:
|
|
mid_prom = self.get_task_token("mask")
|
|
|
|
# create new proms
|
|
proms = torch.cat(
|
|
([ pre_prom ] if pre_prom is not None else []) +
|
|
[self.get_task_token("soe")] +
|
|
[ mid_prom ] + # is <mask> if task is CSE
|
|
[self.get_task_token("eoe")] +
|
|
([ post_prom ] if post_prom is not None else [])
|
|
)
|
|
# create new resp
|
|
resps = torch.cat(
|
|
([ pre_prom ] if pre_prom is not None else []) +
|
|
[ edit_prom ] +
|
|
([ post_prom ] if post_prom is not None else [])
|
|
)
|
|
else:
|
|
raise Exception(f'Undefined task: {task}')
|
|
"""
|
|
|
|
"""
|
|
# emulate SVC
|
|
# takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
|
|
# targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
|
|
|
|
# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
|
|
# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
|
|
# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
|
|
elif task == "svc":
|
|
# sample a random, clean utterance for the target speaker
|
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
|
# sample a reference clip from a different speaker
|
|
ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
|
|
#
|
|
resps =
|
|
# stitch together the promps
|
|
proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
|
|
|
|
# set the text prompt to empty to train without a guided text prompt
|
|
if random.random() < 0.5:
|
|
text = torch.tensor([1, 2]).to(self.text_dtype)
|
|
"""
|
|
|
|
# trim to fit to requested prom/resps levels
|
|
proms = proms[:, :cfg.model.prom_levels]
|
|
resps = resps[:, :cfg.model.prom_levels]
|
|
|
|
|
|
return dict(
|
|
index=index,
|
|
path=Path(path),
|
|
spkr_name=spkr_name,
|
|
spkr_id=spkr_id,
|
|
task=task,
|
|
lang=lang,
|
|
text=text,
|
|
proms=proms,
|
|
resps=resps,
|
|
)
|
|
|
|
def head_(self, n):
    """Cap how many items ``__len__`` reports at ``n``.

    A falsy ``n`` (``None`` or ``0``) means no cap, since ``__len__`` falls back
    to the full size via ``self._head or ...``.
    """
    setattr(self, "_head", n)
|
|
|
|
def training_(self, value):
    """Set the dataset's ``training`` flag to ``value``."""
    setattr(self, "training", value)
|
|
|
|
def __len__(self):
    """Return the number of items in one pass, optionally capped by ``self._head``.

    What counts as an item depends on ``self.sampler_type``:
    - ``"group"``: speaker groups;
    - ``"speaker"``: speakers;
    - anything else: entries of ``self.paths``.

    A truthy ``self._head`` caps the result. A falsy one (``None``/``0``) leaves
    it uncapped.
    """
    if self.sampler_type == "group":
        pool = self.spkr_groups
    elif self.sampler_type == "speaker":
        pool = self.spkrs
    else:
        pool = self.paths

    total = len(pool)
    return min(total, self._head or total)
|
|
|
|
def pin_memory(self):
    """Replace ``text``, ``proms``, ``resps`` and ``resp`` with their ``pin_memory()`` copies.

    Returns ``self``. The attributes are processed in that order, and a missing
    one raises ``AttributeError``.

    NOTE(review): nothing visible in this module assigns these four attributes
    on the dataset. Confirm this method is still reachable.
    """
    for attr_name in ("text", "proms", "resps", "resp"):
        pinned = getattr(self, attr_name).pin_memory()
        setattr(self, attr_name, pinned)
    return self
|
|
|
|
|
|
def collate_fn(samples: list[dict]):
    """Transpose a list of per-sample dicts into a single dict of per-key lists.

    The keys, and their order, come from the first sample. Every sample must
    provide each of those keys.
    """
    batch: dict[str, Any] = {}
    for key in samples[0].keys():
        batch[key] = [sample[key] for sample in samples]
    return batch
|
|
|
|
|
|
def _seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from ``torch.initial_seed()``.

    The torch seed is reduced modulo 2**32 so it fits NumPy's seed range.
    ``worker_id`` is not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
|
|
def _create_dataloader(dataset, training):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``cfg``.

    If ``dataset.sampler`` is a ``BatchedOrderedSampler`` it already yields whole
    batches, so it is passed as ``batch_sampler``. Any other sampler is combined
    with:
    - the dataset's ``shuffle`` flag;
    - the training or evaluation batch size;
    - ``drop_last``, enabled only when ``training``.
    """
    # (a DistributedSampler-based setup for distributed training is currently disabled)
    if isinstance(dataset.sampler, BatchedOrderedSampler):
        sampling = dict(batch_sampler=dataset.sampler)
    else:
        sampling = dict(
            shuffle=dataset.shuffle,
            batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
            drop_last=training,
            sampler=dataset.sampler,
        )

    worker_count = cfg.dataset.workers
    return DataLoader(
        dataset=dataset,
        num_workers=worker_count,
        collate_fn=collate_fn,
        # workers are kept alive between epochs only when more than one is configured
        persistent_workers=worker_count > 1,
        pin_memory=False,
        worker_init_fn=_seed_worker,
        **sampling,
    )
|
|
|
|
def create_datasets():
    """Build the training ``Dataset`` and a validation ``Dataset`` sharing its phone symbol map.

    Returns:
        ``(train_dataset, val_dataset)``
    """
    train = Dataset( training=True )
    # reuse the training phone_symmap so both splits map phonemes identically
    return train, Dataset( phone_symmap=train.phone_symmap, training=False )
|
|
|
|
|
|
def create_train_val_dataloader():
    """Create the train, subtrain and validation dataloaders and log their statistics.

    The "subtrain" loader reads a copy of the training dataset but is built with
    ``training=False``. When that copy samples by path, it is truncated to
    ``cfg.evaluation.size`` items.

    Returns:
        ``(train_dl, subtrain_dl, val_dl)``
    """
    train_dataset, val_dataset = create_datasets()

    # deepcopy can fail on unpicklable members (e.g. a torch generator);
    # fall back to building a fresh training dataset in that case
    try:
        subtrain_dataset = copy.deepcopy(train_dataset)
    except Exception:
        subtrain_dataset = Dataset( training=True )

    if subtrain_dataset.sampler_type == "path":
        subtrain_dataset.head_(cfg.evaluation.size)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    for symmap_attr in ("phone_symmap", "spkr_symmap", "spkr_group_symmap"):
        _logger.info(str(getattr(train_dataset, symmap_attr)))

    splits = (("train", train_dataset), ("val", val_dataset), ("subtrain", subtrain_dataset))
    for label, split in splits:
        _logger.info(f"#samples ({label}): {len(split)}.")
    for label, split in splits:
        _logger.info(f"#duration ({label}): {str(split.duration)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
|
|
|
|
# parse the dataset into per-speaker metadata files that are easier to sample from
|
|
def create_dataset_metadata( skip_existing=True ):
    """Scan the configured dataset directories and write one metadata JSON file per speaker.

    Each directory in ``cfg.dataset.training``, ``cfg.dataset.validation`` and
    ``cfg.dataset.noise`` is processed. Its metadata goes to
    ``<cfg.metadata_dir>/<dir with cfg.data_dir stripped>.json``, mapping each
    utterance id to its metadata.

    Metadata comes from the ``"metadata"`` entry of the quantized audio file
    (``text``, ``phonemes``, ``language``, ``duration``). Only when none of those
    is present is the utterance's phoneme JSON file used instead.

    Args:
        skip_existing: do not reprocess utterance ids already present in a
            speaker's metadata file.
    """
    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

    def add( dir, type="training", audios=True, texts=True ):
        # `type` is accepted for parity with the hdf5 exporter's `add`;
        # the metadata file layout does not depend on it.
        name = str(dir).replace(root, "")
        speaker_name = name

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        # best-effort: an unreadable metadata file is rebuilt from scratch, but report it
        try:
            metadata = json.loads(metadata_path.read_text(encoding="utf-8")) if metadata_path.exists() else {}
        except Exception as e:
            tqdm.write(f'Error while reading {metadata_path}: {e}')
            metadata = {}

        if not os.path.isdir(f'{root}/{name}/'):
            return

        files = os.listdir(f'{root}/{name}/')

        # grab IDs for every file (strip both the quant and the phoneme extension)
        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            try:
                quant_path = f'{root}/{name}/{utt_id}{_get_quant_extension()}'
                phone_path = f'{root}/{name}/{utt_id}{_get_phone_extension()}'
                quant_exists = os.path.exists(quant_path) if audios else True
                text_exists = os.path.exists(phone_path) if texts else True

                if not quant_exists:
                    continue

                if skip_existing and utt_id in metadata:
                    continue

                if utt_id not in metadata:
                    metadata[utt_id] = {}

                utterance_metadata = {}
                if audios:
                    # NOTE: allow_pickle=True unpickles the file contents; only run on trusted data.
                    dac = np.load(quant_path, allow_pickle=True)[()]
                    dac_metadata = dac["metadata"]

                    for k in ("text", "phonemes", "language"):
                        if k in dac_metadata:
                            utterance_metadata[k] = dac_metadata[k]
                    if "original_length" in dac_metadata and "sample_rate" in dac_metadata:
                        utterance_metadata["duration"] = dac_metadata["original_length"] / dac_metadata["sample_rate"]

                # text: fall back to the phoneme file only when the audio carried no metadata
                if texts and text_exists and not utterance_metadata:
                    utterance_metadata = json.loads(Path(phone_path).read_text(encoding="utf-8"))

                metadata[utt_id].update(utterance_metadata)

            except Exception as e:
                tqdm.write(f'Error while processing {utt_id}: {e}')

        with open(str(metadata_path), "w", encoding="utf-8") as f:
            f.write( json.dumps( metadata ) )

    # training
    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
        add( data_dir, type="training" )

    # validation
    for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
        add( data_dir, type="validation" )

    # noise
    for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
        add( data_dir, type="noise", texts=False )
|
|
|
|
# parse yaml to create an hdf5 file
|
|
def create_dataset_hdf5( skip_existing=True ):
    """Pack the quantized audio and phoneme tokens of every configured dataset directory into ``cfg.hdf5``.

    Each utterance becomes the HDF5 group ``<type>/<speaker>/<id>``, holding:
    - an ``audio`` dataset (codes stored as int16);
    - a ``text`` dataset (tokenizer ids stored as uint8);
    - the utterance metadata as group attributes.

    The phoneme symbol map is (re)written as the ``symmap`` dataset. The HDF5
    file is closed on exit, including when an error aborts processing.

    Args:
        skip_existing: skip utterances whose group already exists in the file.
    """
    cfg.dataset.use_hdf5 = True
    cfg.load_hdf5(write=True)
    hf = cfg.hdf5

    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    def add( dir, type="training", audios=True, texts=True ):
        name = str(dir).replace(root, "")

        # yucky: LibriTTS-R entries are keyed under the LibriVox speaker name
        speaker_name = name
        if "LibriTTS-R" in speaker_name:
            speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        # treat an unreadable metadata file as empty instead of aborting the whole export
        try:
            metadata = json.loads(metadata_path.read_text(encoding="utf-8")) if metadata_path.exists() else {}
        except Exception as e:
            tqdm.write(f'Error while reading {metadata_path}: {e}')
            metadata = {}

        if not os.path.isdir(f'{root}/{name}/'):
            return

        files = os.listdir(f'{root}/{name}/')

        # grab IDs for every file (strip both the quant and the phoneme extension)
        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

        # Disabled maintenance tool: re-tokenize stored `text` datasets from the metadata
        # phonemes (rephonemizes if you fuck up and use an old tokenizer...):
        # for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
        #     key = f'{type}/{speaker_name}/{id}'
        #     if key not in hf or "phonemes" not in entry or "text" not in hf[key]:
        #         continue
        #     phn = np.array(cfg.tokenizer.encode("".join(entry["phonemes"]))).astype(np.uint8)
        #     del hf[key]["text"]
        #     hf[key].create_dataset('text', data=phn, compression='lzf')

        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            try:
                quant_path = f'{root}/{name}/{utt_id}{_get_quant_extension()}'
                phone_path = f'{root}/{name}/{utt_id}{_get_phone_extension()}'
                quant_exists = os.path.exists(quant_path) if audios else True
                text_exists = os.path.exists(phone_path) if texts else True

                if not quant_exists:
                    continue

                key = f'{type}/{speaker_name}/{utt_id}'

                if skip_existing and key in hf:
                    continue

                group = hf.create_group(key) if key not in hf else hf[key]

                if utt_id not in metadata:
                    metadata[utt_id] = {}

                utterance_metadata = {}

                # audio
                if audios:
                    # NOTE: allow_pickle=True unpickles the file contents; only run on trusted data.
                    dac = np.load(quant_path, allow_pickle=True)[()]
                    qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

                    dac_metadata = dac["metadata"]
                    for k in ("text", "phonemes", "language"):
                        if k in dac_metadata:
                            utterance_metadata[k] = dac_metadata[k]
                    if "original_length" in dac_metadata and "sample_rate" in dac_metadata:
                        utterance_metadata["duration"] = dac_metadata["original_length"] / dac_metadata["sample_rate"]

                    if "audio" not in group:
                        group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')

                # text
                if texts:
                    if not utterance_metadata and text_exists:
                        utterance_metadata = json.loads(Path(phone_path).read_text(encoding="utf-8"))

                    phn = "".join(utterance_metadata["phonemes"])
                    phn = cfg.tokenizer.encode(phn)
                    # NOTE(review): token ids above 255 would wrap in uint8 — confirm the tokenizer vocabulary fits
                    phn = np.array(phn).astype(np.uint8)

                    if "text" not in group:
                        group.create_dataset('text', data=phn, compression='lzf')

                for k, v in utterance_metadata.items():
                    group.attrs[k] = v
                    metadata[utt_id][k] = v

            except Exception as e:
                tqdm.write(f'Error while processing {utt_id}: {e}')

        # Disabled: writing the collected metadata back to disk.
        # with open(str(metadata_path), "w", encoding="utf-8") as f:
        #     f.write( json.dumps( metadata ) )

    try:
        symmap = get_phone_symmap()

        # training
        for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
            add( data_dir, type="training" )

        # validation
        for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
            add( data_dir, type="validation" )

        # noise
        for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
            add( data_dir, type="noise", texts=False )

        # write symmap
        if "symmap" in hf:
            del hf['symmap']

        hf.create_dataset('symmap', data=json.dumps(symmap))
    finally:
        # always release the HDF5 handle so a failed run does not leave it open/unflushed
        hf.close()
|
|
|
|
def transcribe_dataset():
    """Transcribe every audio file under ``./voices/<dataset>/<speaker>/`` with WhisperX.

    Results are stored per speaker in
    ``./training/metadata/<dataset>/<speaker>/whisper.json``. That file maps
    each filename to ``{"segments", "language", "text", "start", "end"}``:
    - ``start`` is the earliest segment start, or 0 when WhisperX returns no segments;
    - ``end`` is the latest segment end, or 0 when there are no segments.

    Languages other than Japanese are aligned as English. With the settings
    below, a CUDA device is required.
    """
    import json
    import os
    from pathlib import Path

    import whisperx
    from tqdm.auto import tqdm

    # to-do: use argparser
    batch_size = 16
    device = "cuda"
    dtype = "float16"
    model_name = "large-v3"

    input_audio = "voices"
    output_dataset = "training/metadata"

    skip_existing = True
    diarize = False

    model = whisperx.load_model(model_name, device, compute_type=dtype)
    align_model, align_model_metadata, align_model_language = (None, None, None)
    diarize_model = whisperx.DiarizationPipeline(device=device) if diarize else None

    for dataset_name in os.listdir(f'./{input_audio}/'):
        if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
            continue

        for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
            if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
                continue

            outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')

            if outpath.exists():
                metadata = json.loads(outpath.read_text(encoding='utf-8'))
            else:
                os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
                metadata = {}

            for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"):
                if skip_existing and filename in metadata:
                    continue

                if ".json" in filename:
                    continue

                inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'

                if os.path.isdir(inpath):
                    continue

                metadata[filename] = {
                    "segments": [],
                    "language": "",
                    "text": "",
                    "start": 0,
                    "end": 0,
                }

                audio = whisperx.load_audio(inpath)
                result = model.transcribe(audio, batch_size=batch_size)
                language = result["language"]

                # only Japanese keeps its detected language; everything else is aligned as English
                if language[:2] not in ["ja"]:
                    language = "en"

                # reload the alignment model only when the language changes
                if align_model_language != language:
                    tqdm.write(f'Loading language: {language}')
                    align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
                    align_model_language = language

                result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)

                metadata[filename]["segments"] = result["segments"]
                metadata[filename]["language"] = language

                if diarize_model is not None:
                    diarize_segments = diarize_model(audio)
                    result = whisperx.assign_word_speakers(diarize_segments, result)

                segments = result["segments"]
                metadata[filename]["text"] = " ".join(segment["text"] for segment in segments).strip()
                # Take the true bounds of the speech. A running min seeded with 0
                # (as before) always reported start = 0.
                metadata[filename]["start"] = min((segment["start"] for segment in segments), default=0)
                metadata[filename]["end"] = max((segment["end"] for segment in segments), default=0)

            outpath.write_text(json.dumps(metadata), encoding='utf-8')
|
|
|
|
if __name__ == "__main__":
    import argparse

    # NOTE: the first positional argument of ArgumentParser is `prog`; the old
    # copy-pasted "Save trained model to path." was shown as the program name.
    parser = argparse.ArgumentParser(description="Dataset preparation and inspection utilities.")
    parser.add_argument("--action", type=str, help="transcribe | hdf5 | list-dataset | metadata | sample | tasks")
    parser.add_argument("--tasks", type=str, help="comma-separated task names (used by --action=tasks)")
    args, _unknown = parser.parse_known_args()

    cfg.dataset.workers = 1

    class LoggerOverride:
        """Minimal stand-in for the module logger that prints to stdout."""
        def info(self, *messages):
            print(*messages)

    # Rebinding the module-level logger makes helpers such as
    # create_train_val_dataloader print their statistics directly.
    _logger = LoggerOverride()

    # Previously both of the first two branches tested "hdf5", so
    # create_dataset_hdf5 was unreachable.
    if args.action == "transcribe":
        transcribe_dataset()
    elif args.action == "hdf5":
        create_dataset_hdf5()
    elif args.action == "list-dataset":
        dataset = []
        for group in os.listdir(cfg.data_dir):
            group_dir = cfg.data_dir / group
            if not group_dir.is_dir():
                continue
            for name in os.listdir(group_dir):
                speaker_dir = group_dir / name
                # skip stray files and empty speaker folders
                if not speaker_dir.is_dir() or len(os.listdir(speaker_dir)) == 0:
                    continue
                dataset.append(f'{group}/{name}')

        print(json.dumps(dataset))
    elif args.action == "metadata":
        create_dataset_metadata()
    elif args.action == "sample":
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

        samples = {
            "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
            "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
            "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
        }

        Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)

        for k, batches in samples.items():
            for batch in batches:
                for j in tqdm(range(len(batch['proms'])), desc="Decoding..."):
                    # Decoding to ./data/sample-test/{k}.{i}.{j}.{proms,resps}.wav via
                    # decode_to_file(..., device="cpu") is disabled; only shapes are reported.
                    batch['proms'][j] = batch['proms'][j].shape
                    batch['resps'][j] = batch['resps'][j].shape

        for k, batches in samples.items():
            for i, batch in enumerate(batches):
                print(f'{k}[{i}]:', batch)
    elif args.action == "tasks":
        if not args.tasks:
            parser.error("--tasks is required with --action=tasks")
        cfg.dataset.tasks_list = args.tasks.split(",")

        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
        batch = next(iter(train_dl))

        for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
            if task not in cfg.dataset.tasks_list:
                continue

            print(text, task, cfg.model.prom_levels)
            print( proms.shape, resps.shape )

            tokens = sum(t.shape[0] for t in batch["text"]) + sum(r.shape[0] for r in batch["resps"])
            print( tokens )

            decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
            decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
            break
    else:
        parser.error(f"unknown --action: {args.action!r}")