final tweaks, hopefully

This commit is contained in:
mrq 2024-04-28 22:28:29 -05:00
parent ffc334cf58
commit caad7ee3c9
6 changed files with 220 additions and 81 deletions

View File

@ -0,0 +1,99 @@
import os
import json
import torch
import torchaudio
from tqdm.auto import tqdm
from pathlib import Path
input_dataset = "metadata"
output_dataset = "metadata-cleaned"
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_dataset}/'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
print("Is not dir:", f'./{input_dataset}/{dataset_name}/')
continue
for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}')
continue
inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json')
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
if not inpath.exists():
continue
if outpath.exists():
continue
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
try:
in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read())
except Exception as e:
print("Failed to open metadata file:", inpath)
continue
out_metadata = {}
speaker_metadatas = {}
for filename, result in in_metadata.items():
language = result["language"] if "language" in result else "en"
out_metadata[filename] = {
"segments": [],
"language": language,
"text": "",
"start": 0,
"end": 0,
}
segments = []
text = []
start = 0
end = 0
diarized = False
for segment in result["segments"]:
# diarize split
if "speaker" in segment:
diarized = True
speaker_id = segment["speaker"]
if speaker_id not in speaker_metadatas:
speaker_metadatas[speaker_id] = {}
if filename not in speaker_metadatas[speaker_id]:
speaker_metadatas[speaker_id][filename] = {
"segments": [],
"language": language,
"text": "",
"start": 0,
"end": 0,
}
speaker_metadatas[speaker_id][filename]["segments"].append( segment )
else:
segments.append( segment )
text.append( segment["text"] )
start = min( start, segment["start"] )
end = max( end, segment["end"] )
out_metadata[filename]["segments"] = segments
out_metadata[filename]["text"] = " ".join(text).strip()
out_metadata[filename]["start"] = start
out_metadata[filename]["end"] = end
if len(segments) == 0:
del out_metadata[filename]
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
for speaker_id, out_metadata in speaker_metadatas.items():
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))

View File

@ -8,26 +8,27 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voice"
# things that could be args
input_audio = "voices"
input_metadata = "metadata"
output_dataset = "training-24K"
device = "cuda"
slice = "auto"
missing = {
"transcription": [],
"audio": []
}
device = "cuda"
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_audio}/'):
for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
continue
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
continue
@ -36,24 +37,23 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
if not metadata_path.exists():
#print("Does not exist:", metadata_path)
missing["transcription"].append(str(metadata_path))
continue
try:
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
except Exception as e:
#print("Failed to load metadata:", metadata_path, e)
missing["transcription"].append(str(metadata_path))
continue
txts = []
wavs = []
for filename in metadata.keys():
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
for filename in sorted(metadata.keys()):
inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
if not inpath.exists():
#print("Does not exist:", inpath)
missing["audio"].append(str(inpath))
continue
@ -63,9 +63,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
waveform, sample_rate = None, None
language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
if len(metadata[filename]["segments"]) == 0:
id = pad(0, 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
if len(metadata[filename]["segments"]) == 0 or not use_slices:
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
text = metadata[filename]["text"]
if len(text) == 0:
@ -91,8 +90,10 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
sample_rate
))
else:
i = 0
for segment in metadata[filename]["segments"]:
id = pad(segment['id'], 4)
id = pad(i, 4)
i = i + 1
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():

View File

@ -7,30 +7,35 @@ import whisperx
from tqdm.auto import tqdm
from pathlib import Path
device = "cuda"
# should be args
batch_size = 16
device = "cuda"
dtype = "float16"
model_size = "large-v2"
model_name = "large-v3"
input_audio = "voice"
input_audio = "voices"
output_dataset = "metadata"
skip_existing = True
diarize = False
model = whisperx.load_model(model_size, device, compute_type=dtype)
#
model = whisperx.load_model(model_name, device, compute_type=dtype)
align_model, align_model_metadata, align_model_language = (None, None, None)
if diarize:
diarize_model = whisperx.DiarizationPipeline(device=device)
else:
diarize_model = None
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
for dataset_name in os.listdir(f'./{input_audio}/'):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
continue
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
continue
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
@ -46,18 +51,29 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
if skip_existing and filename in metadata:
continue
if ".json" in filename:
continue
inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
if os.path.isdir(inpath):
continue
metadata[filename] = {
"segments": [],
"language": "",
"text": [],
"text": "",
"start": 0,
"end": 0,
}
audio = whisperx.load_audio(inpath)
result = model.transcribe(audio, batch_size=batch_size)
language = result["language"]
if language[:2] not in ["ja"]:
language = "en"
if align_model_language != language:
tqdm.write(f'Loading language: {language}')
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
@ -68,12 +84,20 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
metadata[filename]["segments"] = result["segments"]
metadata[filename]["language"] = language
if diarize_model is not None:
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)
text = []
start = 0
end = 0
for segment in result["segments"]:
id = len(text)
text.append( segment["text"] )
metadata[filename]["segments"][id]["id"] = id
start = min( start, segment["start"] )
end = max( end, segment["end"] )
metadata[filename]["text"] = " ".join(text).strip()
metadata[filename]["start"] = start
metadata[filename]["end"] = end
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))

View File

@ -33,6 +33,14 @@ class _Config:
def cache_dir(self):
return self.relpath / ".cache"
@property
def data_dir(self):
return self.relpath / "data"
@property
def metadata_dir(self):
return self.relpath / "metadata"
@property
def ckpt_dir(self):
return self.relpath / "ckpt"

View File

@ -85,46 +85,45 @@ def _calculate_durations( type="training" ):
@cfg.diskcache()
def _load_paths(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
def _load_paths_from_metadata(dataset_name, type="training", validate=False):
data_dir = dataset_name if cfg.dataset.use_hdf5 else cfg.data_dir / dataset_name
def _load_paths_from_metadata(data_dir, type="training", validate=False):
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def _validate( entry ):
if "phones" not in entry or "duration" not in entry:
return False
phones = entry['phones']
duration = entry['duration']
phones = entry['phones'] if "phones" in entry else 0
duration = entry['duration'] if "duration" in entry else 0
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += entry['duration']
_total_durations[type] += duration
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = data_dir / "metadata.json"
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
metadata = {}
if cfg.dataset.use_metadata and metadata_path.exists():
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
if len(metadata) == 0:
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
def key( dir, id ):
if not cfg.dataset.use_hdf5:
return data_dir / id
return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
def _get_hdf5_path(path):
path = str(path)
if path[:2] != "./":
path = f'./{path}'
res = path.replace(cfg.cfg_path, "")
return res
# to-do: better validation
#print(path)
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
@ -137,7 +136,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
_total_durations[type] += child.attrs['duration']
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
key = f"/{type}{_get_hdf5_path(data_dir)}"
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
@ -427,6 +426,9 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
text = cfg.hdf5[key]["text"][:]
resps = cfg.hdf5[key]["audio"][:, :]
@ -752,6 +754,10 @@ def create_train_val_dataloader():
# parse dataset into better to sample metadata
def create_dataset_metadata():
# need to fix
if True:
return
cfg.dataset.validate = False
cfg.dataset.use_hdf5 = False
@ -805,14 +811,19 @@ def create_dataset_hdf5( skip_existing=True ):
symmap = get_phone_symmap()
root = cfg.cfg_path
root = str(cfg.data_dir)
metadata_root = str(cfg.metadata_dir)
hf = cfg.hdf5
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
def add( dir, type="training", audios=True, texts=True ):
name = "./" + str(dir)
name = name .replace(root, "")
metadata = {}
name = str(dir)
name = name.replace(root, "")
metadata_path = Path(f"{metadata_root}/{name}.json")
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
if not os.path.isdir(f'{root}/{name}/'):
return
@ -831,36 +842,38 @@ def create_dataset_hdf5( skip_existing=True ):
continue
key = f'{type}/{name}/{id}'
if key in hf:
if skip_existing:
continue
del hf[key]
group = hf.create_group(key)
if skip_existing and key in hf:
continue
group = hf.create_group(key) if key not in hf else hf[key]
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = name
metadata[id] = {}
if id not in metadata:
metadata[id] = {}
# audio
if audios:
qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)
if _get_quant_extension() == ".dac":
if "audio" in group:
del group["audio"]
duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
duration = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
metadata[id]["metadata"] = {
"original_length": qnt["metadata"]["original_length"],
"sample_rate": qnt["metadata"]["sample_rate"],
"original_length": dac["metadata"]["original_length"],
"sample_rate": dac["metadata"]["sample_rate"],
}
else:
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
duration = qnt.shape[0] / 75
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
qnt = qnt.numpy().astype(np.int16)
if "audio" not in group:
group.create_dataset('audio', data=qnt, compression='lzf')
group.attrs['duration'] = duration
metadata[id]["duration"] = duration
@ -870,52 +883,46 @@ def create_dataset_hdf5( skip_existing=True ):
# text
if texts:
if _get_quant_extension() == ".json":
if _get_phone_extension() == ".json":
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
content = json_metadata["phonemes"]
txt = json_metadata["text"]
else:
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
"""
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
for s in set(phones):
if s not in symmap:
symmap[s] = len(symmap.keys())
phn = [ symmap[s] for s in phones ]
"""
txt = ""
phn = cfg.tokenizer.encode("".join(content))
phn = np.array(phn).astype(np.uint8)
if "text" in group:
del group["text"]
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
if "text" not in group:
group.create_dataset('text', data=phn, compression='lzf')
group.attrs['phonemes'] = len(phn)
group.attrs['transcription'] = txt
metadata[id]["phones"] = len(phn)
metadata[id]["transcription"] = txt
else:
group.attrs['phonemes'] = 0
metadata[id]["phones"] = 0
except Exception as e:
pass
raise e
#pass
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
# write symmap

View File

@ -340,7 +340,7 @@ def example_usage():
def _load_quants(path) -> Tensor:
if cfg.inference.audio_backend == "dac":
qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
qnt = _load_quants("./data/qnt")
@ -350,7 +350,7 @@ def example_usage():
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
]
proms_list = [
qnt[:75*3, :].to(device),
qnt.to(device),
]
resps_list = [
qnt.to(device),
@ -407,7 +407,7 @@ def example_usage():
frozen_params = set()
for k in list(embeddings.keys()):
if re.findall(r'_emb\.', k):
if re.findall(r'_emb.', k):
frozen_params.add(k)
else:
del embeddings[k]