final tweaks, hopefully
This commit is contained in:
parent
ffc334cf58
commit
caad7ee3c9
99
scripts/cleanup_dataset.py
Normal file
99
scripts/cleanup_dataset.py
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
input_dataset = "metadata"
|
||||||
|
output_dataset = "metadata-cleaned"
|
||||||
|
|
||||||
|
def pad(num, zeroes):
|
||||||
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
|
for dataset_name in os.listdir(f'./{input_dataset}/'):
|
||||||
|
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
|
||||||
|
print("Is not dir:", f'./{input_dataset}/{dataset_name}/')
|
||||||
|
continue
|
||||||
|
|
||||||
|
for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"):
|
||||||
|
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
|
||||||
|
print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||||
|
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||||
|
|
||||||
|
if not inpath.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if outpath.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read())
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to open metadata file:", inpath)
|
||||||
|
continue
|
||||||
|
|
||||||
|
out_metadata = {}
|
||||||
|
speaker_metadatas = {}
|
||||||
|
|
||||||
|
for filename, result in in_metadata.items():
|
||||||
|
language = result["language"] if "language" in result else "en"
|
||||||
|
out_metadata[filename] = {
|
||||||
|
"segments": [],
|
||||||
|
"language": language,
|
||||||
|
"text": "",
|
||||||
|
"start": 0,
|
||||||
|
"end": 0,
|
||||||
|
}
|
||||||
|
segments = []
|
||||||
|
text = []
|
||||||
|
start = 0
|
||||||
|
end = 0
|
||||||
|
diarized = False
|
||||||
|
|
||||||
|
for segment in result["segments"]:
|
||||||
|
# diarize split
|
||||||
|
if "speaker" in segment:
|
||||||
|
diarized = True
|
||||||
|
speaker_id = segment["speaker"]
|
||||||
|
if speaker_id not in speaker_metadatas:
|
||||||
|
speaker_metadatas[speaker_id] = {}
|
||||||
|
|
||||||
|
if filename not in speaker_metadatas[speaker_id]:
|
||||||
|
speaker_metadatas[speaker_id][filename] = {
|
||||||
|
"segments": [],
|
||||||
|
"language": language,
|
||||||
|
"text": "",
|
||||||
|
"start": 0,
|
||||||
|
"end": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
speaker_metadatas[speaker_id][filename]["segments"].append( segment )
|
||||||
|
else:
|
||||||
|
segments.append( segment )
|
||||||
|
|
||||||
|
text.append( segment["text"] )
|
||||||
|
start = min( start, segment["start"] )
|
||||||
|
end = max( end, segment["end"] )
|
||||||
|
|
||||||
|
out_metadata[filename]["segments"] = segments
|
||||||
|
out_metadata[filename]["text"] = " ".join(text).strip()
|
||||||
|
out_metadata[filename]["start"] = start
|
||||||
|
out_metadata[filename]["end"] = end
|
||||||
|
|
||||||
|
if len(segments) == 0:
|
||||||
|
del out_metadata[filename]
|
||||||
|
|
||||||
|
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
|
||||||
|
|
||||||
|
for speaker_id, out_metadata in speaker_metadatas.items():
|
||||||
|
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
|
||||||
|
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||||
|
|
||||||
|
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
|
|
@ -8,26 +8,27 @@ from pathlib import Path
|
||||||
from vall_e.emb.g2p import encode as valle_phonemize
|
from vall_e.emb.g2p import encode as valle_phonemize
|
||||||
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
||||||
|
|
||||||
input_audio = "voice"
|
# things that could be args
|
||||||
|
input_audio = "voices"
|
||||||
input_metadata = "metadata"
|
input_metadata = "metadata"
|
||||||
output_dataset = "training-24K"
|
output_dataset = "training-24K"
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
slice = "auto"
|
||||||
missing = {
|
missing = {
|
||||||
"transcription": [],
|
"transcription": [],
|
||||||
"audio": []
|
"audio": []
|
||||||
}
|
}
|
||||||
|
|
||||||
device = "cuda"
|
|
||||||
|
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
for dataset_name in os.listdir(f'./{input_audio}/'):
|
for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
||||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
|
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
|
for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
|
||||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
|
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
|
||||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
|
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
|
||||||
continue
|
continue
|
||||||
|
@ -36,24 +37,23 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
|
|
||||||
metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
|
metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
|
||||||
if not metadata_path.exists():
|
if not metadata_path.exists():
|
||||||
#print("Does not exist:", metadata_path)
|
|
||||||
missing["transcription"].append(str(metadata_path))
|
missing["transcription"].append(str(metadata_path))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
|
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
#print("Failed to load metadata:", metadata_path, e)
|
|
||||||
missing["transcription"].append(str(metadata_path))
|
missing["transcription"].append(str(metadata_path))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
txts = []
|
txts = []
|
||||||
wavs = []
|
wavs = []
|
||||||
|
|
||||||
for filename in metadata.keys():
|
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
|
||||||
|
|
||||||
|
for filename in sorted(metadata.keys()):
|
||||||
inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
|
inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
|
||||||
if not inpath.exists():
|
if not inpath.exists():
|
||||||
#print("Does not exist:", inpath)
|
|
||||||
missing["audio"].append(str(inpath))
|
missing["audio"].append(str(inpath))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -63,9 +63,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
waveform, sample_rate = None, None
|
waveform, sample_rate = None, None
|
||||||
language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
|
language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
|
||||||
|
|
||||||
if len(metadata[filename]["segments"]) == 0:
|
if len(metadata[filename]["segments"]) == 0 or not use_slices:
|
||||||
id = pad(0, 4)
|
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
|
||||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
|
|
||||||
text = metadata[filename]["text"]
|
text = metadata[filename]["text"]
|
||||||
|
|
||||||
if len(text) == 0:
|
if len(text) == 0:
|
||||||
|
@ -91,8 +90,10 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
sample_rate
|
sample_rate
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
|
i = 0
|
||||||
for segment in metadata[filename]["segments"]:
|
for segment in metadata[filename]["segments"]:
|
||||||
id = pad(segment['id'], 4)
|
id = pad(i, 4)
|
||||||
|
i = i + 1
|
||||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
|
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
|
||||||
|
|
||||||
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
|
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
|
||||||
|
|
|
@ -7,30 +7,35 @@ import whisperx
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
device = "cuda"
|
# should be args
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
|
device = "cuda"
|
||||||
dtype = "float16"
|
dtype = "float16"
|
||||||
model_size = "large-v2"
|
model_name = "large-v3"
|
||||||
|
|
||||||
input_audio = "voice"
|
input_audio = "voices"
|
||||||
output_dataset = "metadata"
|
output_dataset = "metadata"
|
||||||
|
|
||||||
skip_existing = True
|
skip_existing = True
|
||||||
|
diarize = False
|
||||||
|
|
||||||
model = whisperx.load_model(model_size, device, compute_type=dtype)
|
#
|
||||||
|
model = whisperx.load_model(model_name, device, compute_type=dtype)
|
||||||
align_model, align_model_metadata, align_model_language = (None, None, None)
|
align_model, align_model_metadata, align_model_language = (None, None, None)
|
||||||
|
if diarize:
|
||||||
|
diarize_model = whisperx.DiarizationPipeline(device=device)
|
||||||
|
else:
|
||||||
|
diarize_model = None
|
||||||
|
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
for dataset_name in os.listdir(f'./{input_audio}/'):
|
for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
||||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
|
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
|
||||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
|
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
|
||||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||||
|
@ -46,18 +51,29 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
if skip_existing and filename in metadata:
|
if skip_existing and filename in metadata:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if ".json" in filename:
|
||||||
|
continue
|
||||||
|
|
||||||
inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
|
inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
|
||||||
|
|
||||||
|
if os.path.isdir(inpath):
|
||||||
|
continue
|
||||||
|
|
||||||
metadata[filename] = {
|
metadata[filename] = {
|
||||||
"segments": [],
|
"segments": [],
|
||||||
"language": "",
|
"language": "",
|
||||||
"text": [],
|
"text": "",
|
||||||
|
"start": 0,
|
||||||
|
"end": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
audio = whisperx.load_audio(inpath)
|
audio = whisperx.load_audio(inpath)
|
||||||
result = model.transcribe(audio, batch_size=batch_size)
|
result = model.transcribe(audio, batch_size=batch_size)
|
||||||
language = result["language"]
|
language = result["language"]
|
||||||
|
|
||||||
|
if language[:2] not in ["ja"]:
|
||||||
|
language = "en"
|
||||||
|
|
||||||
if align_model_language != language:
|
if align_model_language != language:
|
||||||
tqdm.write(f'Loading language: {language}')
|
tqdm.write(f'Loading language: {language}')
|
||||||
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
|
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
|
||||||
|
@ -68,12 +84,20 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
metadata[filename]["segments"] = result["segments"]
|
metadata[filename]["segments"] = result["segments"]
|
||||||
metadata[filename]["language"] = language
|
metadata[filename]["language"] = language
|
||||||
|
|
||||||
|
if diarize_model is not None:
|
||||||
|
diarize_segments = diarize_model(audio)
|
||||||
|
result = whisperx.assign_word_speakers(diarize_segments, result)
|
||||||
|
|
||||||
text = []
|
text = []
|
||||||
|
start = 0
|
||||||
|
end = 0
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
id = len(text)
|
|
||||||
text.append( segment["text"] )
|
text.append( segment["text"] )
|
||||||
metadata[filename]["segments"][id]["id"] = id
|
start = min( start, segment["start"] )
|
||||||
|
end = max( end, segment["end"] )
|
||||||
|
|
||||||
metadata[filename]["text"] = " ".join(text).strip()
|
metadata[filename]["text"] = " ".join(text).strip()
|
||||||
|
metadata[filename]["start"] = start
|
||||||
|
metadata[filename]["end"] = end
|
||||||
|
|
||||||
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
|
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
|
|
@ -33,6 +33,14 @@ class _Config:
|
||||||
def cache_dir(self):
|
def cache_dir(self):
|
||||||
return self.relpath / ".cache"
|
return self.relpath / ".cache"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_dir(self):
|
||||||
|
return self.relpath / "data"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metadata_dir(self):
|
||||||
|
return self.relpath / "metadata"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ckpt_dir(self):
|
def ckpt_dir(self):
|
||||||
return self.relpath / "ckpt"
|
return self.relpath / "ckpt"
|
||||||
|
|
115
vall_e/data.py
115
vall_e/data.py
|
@ -85,46 +85,45 @@ def _calculate_durations( type="training" ):
|
||||||
|
|
||||||
@cfg.diskcache()
|
@cfg.diskcache()
|
||||||
def _load_paths(dataset, type="training"):
|
def _load_paths(dataset, type="training"):
|
||||||
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||||
|
|
||||||
|
def _load_paths_from_metadata(dataset_name, type="training", validate=False):
|
||||||
|
data_dir = dataset_name if cfg.dataset.use_hdf5 else cfg.data_dir / dataset_name
|
||||||
|
|
||||||
def _load_paths_from_metadata(data_dir, type="training", validate=False):
|
|
||||||
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||||
|
|
||||||
def _validate( entry ):
|
def _validate( entry ):
|
||||||
if "phones" not in entry or "duration" not in entry:
|
phones = entry['phones'] if "phones" in entry else 0
|
||||||
return False
|
duration = entry['duration'] if "duration" in entry else 0
|
||||||
phones = entry['phones']
|
|
||||||
duration = entry['duration']
|
|
||||||
if type not in _total_durations:
|
if type not in _total_durations:
|
||||||
_total_durations[type] = 0
|
_total_durations[type] = 0
|
||||||
_total_durations[type] += entry['duration']
|
_total_durations[type] += duration
|
||||||
|
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
metadata_path = data_dir / "metadata.json"
|
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
if cfg.dataset.use_metadata and metadata_path.exists():
|
if cfg.dataset.use_metadata and metadata_path.exists():
|
||||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||||
|
|
||||||
if len(metadata) == 0:
|
if len(metadata) == 0:
|
||||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
||||||
|
|
||||||
|
|
||||||
def key( dir, id ):
|
def key( dir, id ):
|
||||||
if not cfg.dataset.use_hdf5:
|
if not cfg.dataset.use_hdf5:
|
||||||
return data_dir / id
|
return data_dir / id
|
||||||
|
|
||||||
return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
|
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
|
||||||
|
|
||||||
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
||||||
|
|
||||||
|
|
||||||
def _get_hdf5_path(path):
|
def _get_hdf5_path(path):
|
||||||
path = str(path)
|
# to-do: better validation
|
||||||
if path[:2] != "./":
|
#print(path)
|
||||||
path = f'./{path}'
|
return str(path)
|
||||||
|
|
||||||
res = path.replace(cfg.cfg_path, "")
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||||
data_dir = str(data_dir)
|
data_dir = str(data_dir)
|
||||||
|
@ -137,7 +136,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||||
_total_durations[type] += child.attrs['duration']
|
_total_durations[type] += child.attrs['duration']
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
key = f"/{type}{_get_hdf5_path(data_dir)}"
|
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||||
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
||||||
|
|
||||||
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
||||||
|
@ -427,6 +426,9 @@ class Dataset(_Dataset):
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
key = _get_hdf5_path(path)
|
key = _get_hdf5_path(path)
|
||||||
|
|
||||||
|
if key not in cfg.hdf5:
|
||||||
|
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||||
|
|
||||||
text = cfg.hdf5[key]["text"][:]
|
text = cfg.hdf5[key]["text"][:]
|
||||||
resps = cfg.hdf5[key]["audio"][:, :]
|
resps = cfg.hdf5[key]["audio"][:, :]
|
||||||
|
|
||||||
|
@ -752,6 +754,10 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
# parse dataset into better to sample metadata
|
# parse dataset into better to sample metadata
|
||||||
def create_dataset_metadata():
|
def create_dataset_metadata():
|
||||||
|
# need to fix
|
||||||
|
if True:
|
||||||
|
return
|
||||||
|
|
||||||
cfg.dataset.validate = False
|
cfg.dataset.validate = False
|
||||||
cfg.dataset.use_hdf5 = False
|
cfg.dataset.use_hdf5 = False
|
||||||
|
|
||||||
|
@ -805,14 +811,19 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
|
|
||||||
symmap = get_phone_symmap()
|
symmap = get_phone_symmap()
|
||||||
|
|
||||||
root = cfg.cfg_path
|
root = str(cfg.data_dir)
|
||||||
|
metadata_root = str(cfg.metadata_dir)
|
||||||
hf = cfg.hdf5
|
hf = cfg.hdf5
|
||||||
|
|
||||||
|
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def add( dir, type="training", audios=True, texts=True ):
|
def add( dir, type="training", audios=True, texts=True ):
|
||||||
name = "./" + str(dir)
|
name = str(dir)
|
||||||
name = name .replace(root, "")
|
name = name.replace(root, "")
|
||||||
metadata = {}
|
|
||||||
|
metadata_path = Path(f"{metadata_root}/{name}.json")
|
||||||
|
|
||||||
|
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
||||||
|
|
||||||
if not os.path.isdir(f'{root}/{name}/'):
|
if not os.path.isdir(f'{root}/{name}/'):
|
||||||
return
|
return
|
||||||
|
@ -831,36 +842,38 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
key = f'{type}/{name}/{id}'
|
key = f'{type}/{name}/{id}'
|
||||||
if key in hf:
|
|
||||||
if skip_existing:
|
|
||||||
continue
|
|
||||||
del hf[key]
|
|
||||||
|
|
||||||
group = hf.create_group(key)
|
if skip_existing and key in hf:
|
||||||
|
continue
|
||||||
|
|
||||||
|
group = hf.create_group(key) if key not in hf else hf[key]
|
||||||
|
|
||||||
group.attrs['id'] = id
|
group.attrs['id'] = id
|
||||||
group.attrs['type'] = type
|
group.attrs['type'] = type
|
||||||
group.attrs['speaker'] = name
|
group.attrs['speaker'] = name
|
||||||
|
|
||||||
|
if id not in metadata:
|
||||||
metadata[id] = {}
|
metadata[id] = {}
|
||||||
|
|
||||||
# audio
|
# audio
|
||||||
if audios:
|
if audios:
|
||||||
qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
|
||||||
codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
|
||||||
|
|
||||||
if _get_quant_extension() == ".dac":
|
if _get_quant_extension() == ".dac":
|
||||||
if "audio" in group:
|
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
||||||
del group["audio"]
|
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||||
duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
|
|
||||||
|
duration = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
||||||
metadata[id]["metadata"] = {
|
metadata[id]["metadata"] = {
|
||||||
"original_length": qnt["metadata"]["original_length"],
|
"original_length": dac["metadata"]["original_length"],
|
||||||
"sample_rate": qnt["metadata"]["sample_rate"],
|
"sample_rate": dac["metadata"]["sample_rate"],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
|
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
|
||||||
duration = qnt.shape[0] / 75
|
duration = qnt.shape[0] / 75
|
||||||
|
|
||||||
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
|
qnt = qnt.numpy().astype(np.int16)
|
||||||
|
|
||||||
|
if "audio" not in group:
|
||||||
|
group.create_dataset('audio', data=qnt, compression='lzf')
|
||||||
|
|
||||||
group.attrs['duration'] = duration
|
group.attrs['duration'] = duration
|
||||||
metadata[id]["duration"] = duration
|
metadata[id]["duration"] = duration
|
||||||
|
@ -870,52 +883,46 @@ def create_dataset_hdf5( skip_existing=True ):
|
||||||
|
|
||||||
# text
|
# text
|
||||||
if texts:
|
if texts:
|
||||||
if _get_quant_extension() == ".json":
|
if _get_phone_extension() == ".json":
|
||||||
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
||||||
content = json_metadata["phonemes"]
|
content = json_metadata["phonemes"]
|
||||||
|
txt = json_metadata["text"]
|
||||||
else:
|
else:
|
||||||
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
|
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
|
||||||
|
txt = ""
|
||||||
"""
|
|
||||||
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
|
|
||||||
for s in set(phones):
|
|
||||||
if s not in symmap:
|
|
||||||
symmap[s] = len(symmap.keys())
|
|
||||||
|
|
||||||
phn = [ symmap[s] for s in phones ]
|
|
||||||
"""
|
|
||||||
|
|
||||||
phn = cfg.tokenizer.encode("".join(content))
|
phn = cfg.tokenizer.encode("".join(content))
|
||||||
phn = np.array(phn).astype(np.uint8)
|
phn = np.array(phn).astype(np.uint8)
|
||||||
|
|
||||||
if "text" in group:
|
if "text" not in group:
|
||||||
del group["text"]
|
group.create_dataset('text', data=phn, compression='lzf')
|
||||||
|
|
||||||
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
|
|
||||||
group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
|
|
||||||
|
|
||||||
group.attrs['phonemes'] = len(phn)
|
group.attrs['phonemes'] = len(phn)
|
||||||
|
group.attrs['transcription'] = txt
|
||||||
|
|
||||||
metadata[id]["phones"] = len(phn)
|
metadata[id]["phones"] = len(phn)
|
||||||
|
metadata[id]["transcription"] = txt
|
||||||
else:
|
else:
|
||||||
group.attrs['phonemes'] = 0
|
group.attrs['phonemes'] = 0
|
||||||
metadata[id]["phones"] = 0
|
metadata[id]["phones"] = 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
raise e
|
||||||
|
#pass
|
||||||
|
|
||||||
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
|
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||||
f.write( json.dumps( metadata ) )
|
f.write( json.dumps( metadata ) )
|
||||||
|
|
||||||
|
|
||||||
# training
|
# training
|
||||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
||||||
add( data_dir, type="training" )
|
add( data_dir, type="training" )
|
||||||
|
|
||||||
# validation
|
# validation
|
||||||
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
|
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
|
||||||
add( data_dir, type="validation" )
|
add( data_dir, type="validation" )
|
||||||
|
|
||||||
# noise
|
# noise
|
||||||
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
|
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
|
||||||
add( data_dir, type="noise", texts=False )
|
add( data_dir, type="noise", texts=False )
|
||||||
|
|
||||||
# write symmap
|
# write symmap
|
||||||
|
|
|
@ -340,7 +340,7 @@ def example_usage():
|
||||||
def _load_quants(path) -> Tensor:
|
def _load_quants(path) -> Tensor:
|
||||||
if cfg.inference.audio_backend == "dac":
|
if cfg.inference.audio_backend == "dac":
|
||||||
qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
|
qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
|
||||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
|
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
|
||||||
return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
|
return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
|
||||||
|
|
||||||
qnt = _load_quants("./data/qnt")
|
qnt = _load_quants("./data/qnt")
|
||||||
|
@ -350,7 +350,7 @@ def example_usage():
|
||||||
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
||||||
]
|
]
|
||||||
proms_list = [
|
proms_list = [
|
||||||
qnt[:75*3, :].to(device),
|
qnt.to(device),
|
||||||
]
|
]
|
||||||
resps_list = [
|
resps_list = [
|
||||||
qnt.to(device),
|
qnt.to(device),
|
||||||
|
@ -407,7 +407,7 @@ def example_usage():
|
||||||
|
|
||||||
frozen_params = set()
|
frozen_params = set()
|
||||||
for k in list(embeddings.keys()):
|
for k in list(embeddings.keys()):
|
||||||
if re.findall(r'_emb\.', k):
|
if re.findall(r'_emb.', k):
|
||||||
frozen_params.add(k)
|
frozen_params.add(k)
|
||||||
else:
|
else:
|
||||||
del embeddings[k]
|
del embeddings[k]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user