updated process_datasets.py, added argparsing so I can mostly stop manually editing things, and some other cleanup

mrq 2024-08-05 15:59:25 -05:00
parent debcc93e7e
commit 7cdfa3dc0c
2 changed files with 233 additions and 180 deletions

process_datasets.py View File

@@ -1,191 +1,104 @@
 import os
 import json
+import argparse
 import torch
 import torchaudio
 import numpy as np
 from tqdm.auto import tqdm
 from pathlib import Path
 from vall_e.config import cfg
-# things that could be args
-cfg.sample_rate = 24_000
-cfg.audio_backend = "encodec"
-"""
-cfg.inference.weight_dtype = "bfloat16"
-cfg.inference.dtype = torch.bfloat16
-cfg.inference.amp = True
-"""
-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
-input_audio = "voices"
-input_metadata = "metadata"
-output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
-device = "cuda"
-audio_extension = ".enc"
-if cfg.audio_backend == "dac":
-	audio_extension = ".dac"
-elif cfg.audio_backend == "audiodec":
-	audio_extension = ".dec"
-slice = "auto"
-missing = {
-	"transcription": [],
-	"audio": []
-}
-dataset = []
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
-for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
-	if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
-		print("Is not dir:", f'./{input_audio}/{dataset_name}/')
-		continue
-	for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
-		if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
-			print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
-			continue
-		os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
-		if speaker_id == "Noise":
-			for filename in sorted(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/')):
-				inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{filename}')
-				if _replace_file_extension(outpath, audio_extension).exists():
-					continue
-				waveform, sample_rate = torchaudio.load(inpath)
-				qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-				if cfg.audio_backend == "dac":
-					np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-						"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-						"metadata": {
-							"original_length": qnt.original_length,
-							"sample_rate": qnt.sample_rate,
-							"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-							"chunk_length": qnt.chunk_length,
-							"channels": qnt.channels,
-							"padding": qnt.padding,
-							"dac_version": "1.0.0",
-						},
-					})
-				else:
-					np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-						"codes": qnt.cpu().numpy().astype(np.uint16),
-						"metadata": {
-							"original_length": waveform.shape[-1],
-							"sample_rate": sample_rate,
-						},
-					})
-			continue
-		metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
-		if not metadata_path.exists():
-			missing["transcription"].append(str(metadata_path))
-			continue
-		try:
-			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
-		except Exception as e:
-			missing["transcription"].append(str(metadata_path))
-			continue
-		if f'{dataset_name}/{speaker_id}' not in dataset:
-			dataset.append(f'{dataset_name}/{speaker_id}')
-		txts = []
-		wavs = []
-		use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
-		for filename in sorted(metadata.keys()):
-			inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
-			if not inpath.exists():
-				missing["audio"].append(str(inpath))
-				continue
-			extension = os.path.splitext(filename)[-1][1:]
-			fname = filename.replace(f'.{extension}', "")
-			waveform, sample_rate = None, None
-			language = metadata[filename]["language"] if "language" in metadata[filename] else "en"
-			if len(metadata[filename]["segments"]) == 0 or not use_slices:
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
-				text = metadata[filename]["text"]
-				if len(text) == 0:
-					continue
-				if _replace_file_extension(outpath, audio_extension).exists():
-					continue
-				if waveform is None:
-					waveform, sample_rate = torchaudio.load(inpath)
-					if waveform.shape[0] > 1:
-						waveform = torch.mean(waveform, dim=0, keepdim=True)
-				wavs.append((
-					outpath,
-					text,
-					language,
-					waveform,
-					sample_rate
-				))
-			else:
-				i = 0
-				for segment in metadata[filename]["segments"]:
-					id = pad(i, 4)
-					i = i + 1
-					outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
-					text = segment["text"]
-					if len(text) == 0:
-						continue
-					if _replace_file_extension(outpath, audio_extension).exists():
-						continue
-					if waveform is None:
-						waveform, sample_rate = torchaudio.load(inpath)
-						if waveform.shape[0] > 1:
-							waveform = torch.mean(waveform, dim=0, keepdim=True)
-					start = int(segment['start'] * sample_rate)
-					end = int(segment['end'] * sample_rate)
-					if start < 0:
-						start = 0
-					if end >= waveform.shape[-1]:
-						end = waveform.shape[-1] - 1
-					if end - start < 0:
-						continue
-					wavs.append((
-						outpath,
-						text,
-						language,
-						waveform[:, start:end],
-						sample_rate
-					))
-		if len(wavs) > 0:
-			for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-				try:
-					outpath, text, language, waveform, sample_rate = job
-					phones = valle_phonemize( text, language=language )
+def process_items( items, stride=0 ):
+	items = sorted( items )
+	return items if stride == 0 else [ item for i, item in enumerate( items ) if i % stride == 0 ]
+def process_dataset( args ):
+	# encodec / vocos
+	if args.audio_backend in ["encodec", "vocos"]:
+		audio_extension = ".enc"
+		cfg.sample_rate = 24_000
+		cfg.model.resp_levels = 8
+	elif args.audio_backend == "dac":
+		audio_extension = ".dac"
+		cfg.sample_rate = 44_100
+		cfg.model.resp_levels = 9
+	elif cfg.audio_backend == "audiodec":
+		sample_rate = 48_000
+		audio_extension = ".dec"
+		cfg.model.resp_levels = 8 # ?
+	else:
+		raise Exception(f"Unknown audio backend: {args.audio_backend}")
+	# prepare from args
+	cfg.audio_backend = args.audio_backend # "encodec"
+	cfg.inference.weight_dtype = args.dtype # "bfloat16"
+	cfg.inference.amp = args.amp # False
+	# import after because we've overriden the config above
+	from vall_e.emb.g2p import encode as valle_phonemize
+	from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
+	input_audio = args.input_audio # "voice""
+	input_metadata = args.input_metadata # "metadata"
+	output_group = f"{args.output_group}-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
+	device = args.device # "cuda"
+	raise_exceptions = args.raise_exceptions # False
+	stride = args.stride # 0
+	slice = args.slice # "auto"
+	language_map = {} # k = group, v = language
+	ignore_groups = [] # skip these groups
+	ignore_speakers = [] # skip these speakers
+	only_groups = [] # only process these groups
+	only_speakers = [] # only process these speakers
+	always_slice_groups = [] # always slice from this group
+	missing = {
+		"transcription": [],
+		"audio": []
+	}
+	dataset = []
+	for group_name in sorted(os.listdir(f'./{input_audio}/')):
+		if not os.path.isdir(f'./{input_audio}/{group_name}/'):
+			print("Is not dir:", f'./{input_audio}/{group_name}/')
+			continue
+		if group_name in ignore_groups:
+			continue
+		if only_groups and group_name not in only_groups:
+			continue
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride), desc=f"Processing speaker in {group_name}"):
+			if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
+				print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
+				continue
+			if speaker_id in ignore_speakers:
+				continue
+			if only_speakers and speaker_id not in only_speakers:
+				continue
+			os.makedirs(f'./{output_group}/{group_name}/{speaker_id}/', exist_ok=True)
+			if speaker_id == "Noise":
+				for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
+					inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+					outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{filename}')
+					if _replace_file_extension(outpath, audio_extension).exists():
+						continue
+					waveform, sample_rate = torchaudio.load(inpath)
 					qnt = valle_quantize(waveform, sr=sample_rate, device=device)
 					if cfg.audio_backend == "dac":
@@ -200,10 +113,6 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 								"channels": qnt.channels,
 								"padding": qnt.padding,
 								"dac_version": "1.0.0",
-								"text": text.strip(),
-								"phonemes": "".join(phones),
-								"language": language,
 							},
 						})
 					else:
@@ -212,15 +121,168 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 							"metadata": {
 								"original_length": waveform.shape[-1],
 								"sample_rate": sample_rate,
-								"text": text.strip(),
-								"phonemes": "".join(phones),
-								"language": language,
 							},
 						})
-				except Exception as e:
-					print(f"Failed to quantize: {outpath}:", e)
-					continue
-open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
-open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset))
+				continue
+			metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
+			if not metadata_path.exists():
+				missing["transcription"].append(str(metadata_path))
+				continue
+			try:
+				metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+			except Exception as e:
+				missing["transcription"].append(str(metadata_path))
+				continue
+			if f'{group_name}/{speaker_id}' not in dataset:
+				dataset.append(f'{group_name}/{speaker_id}')
+			txts = []
+			wavs = []
+			use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
+			for filename in sorted(metadata.keys()):
+				inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+				if not inpath.exists():
+					missing["audio"].append(str(inpath))
+					continue
+				extension = os.path.splitext(filename)[-1][1:]
+				fname = filename.replace(f'.{extension}', "")
+				waveform, sample_rate = None, None
+				language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en")
+				if len(metadata[filename]["segments"]) == 0 or not use_slices:
+					outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{fname}.{extension}')
+					text = metadata[filename]["text"]
+					if len(text) == 0:
+						continue
+					if _replace_file_extension(outpath, audio_extension).exists():
+						continue
+					if waveform is None:
+						waveform, sample_rate = torchaudio.load(inpath)
+						if waveform.shape[0] > 1:
+							waveform = torch.mean(waveform, dim=0, keepdim=True)
+					wavs.append((
+						outpath,
+						text,
+						language,
+						waveform,
+						sample_rate
+					))
+				else:
+					i = 0
+					for segment in metadata[filename]["segments"]:
+						id = pad(i, 4)
+						i = i + 1
+						outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+						text = segment["text"]
+						if len(text) == 0:
+							continue
+						if _replace_file_extension(outpath, audio_extension).exists():
+							continue
+						if waveform is None:
+							waveform, sample_rate = torchaudio.load(inpath)
+							if waveform.shape[0] > 1:
+								waveform = torch.mean(waveform, dim=0, keepdim=True)
+						start = int(segment['start'] * sample_rate)
+						end = int(segment['end'] * sample_rate)
+						if start < 0:
+							start = 0
+						if end >= waveform.shape[-1]:
+							end = waveform.shape[-1] - 1
+						if end - start < 0:
+							continue
+						wavs.append((
+							outpath,
+							text,
+							language,
+							waveform[:, start:end],
+							sample_rate
+						))
+			if len(wavs) > 0:
+				for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+					try:
+						outpath, text, language, waveform, sample_rate = job
+						phones = valle_phonemize(text, language=language)
+						qnt = valle_quantize(waveform, sr=sample_rate, device=device)
+						if cfg.audio_backend == "dac":
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": qnt.original_length,
+									"sample_rate": qnt.sample_rate,
+									"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+									"chunk_length": qnt.chunk_length,
+									"channels": qnt.channels,
+									"padding": qnt.padding,
+									"dac_version": "1.0.0",
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+						else:
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": waveform.shape[-1],
+									"sample_rate": sample_rate,
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+					except Exception as e:
+						print(f"Failed to quantize: {outpath}:", e)
+						if raise_exceptions:
+							raise e
+						continue
+	open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
+	open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset))
+def main():
+	parser = argparse.ArgumentParser()
+	parser.add_argument("--audio-backend", type=str, default="encodec")
+	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--input-audio", type=str, default="voices")
+	parser.add_argument("--input-metadata", type=str, default="metadata")
+	parser.add_argument("--output_group", type=str, default="training")
+	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--stride", type=int, default=0)
+	parser.add_argument("--slice", type=str, default="auto")
+	args = parser.parse_args()
+	process_dataset( args )
+if __name__ == "__main__":
+	main()
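
Usage sketch (not part of the commit): with the old hand-edited globals now exposed as flags, the entry point can be driven programmatically as well as from the CLI. The module name process_datasets and its importability are assumptions; the Namespace fields simply mirror the parser defaults defined in main() above.

	# hypothetical invocation mirroring the argparse defaults above
	import argparse
	from process_datasets import process_dataset  # assumes the script is importable as a module

	args = argparse.Namespace(
		audio_backend="encodec",   # or "vocos", "dac", "audiodec"
		dtype="bfloat16",
		amp=False,
		input_audio="voices",
		input_metadata="metadata",
		output_group="training",
		device="cuda",
		raise_exceptions=False,
		stride=0,
		slice="auto",
	)
	process_dataset(args)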

View File

@@ -813,7 +813,7 @@ class Base(nn.Module):
 				inputs_embeds=x,
 				past_key_values=state,
 				position_ids=position_ids,
-				use_cache=True,
+				use_cache=not self.training,
 				# return_dict=True,
 			)
 			if self.n_experts > 1 and self.training:
@@ -1350,15 +1350,6 @@ class Base(nn.Module):
 		x, m = list_to_tensor(x_list)
 		training = self.training
-		# yes, there's a better way.
-		"""
-		training = False
-		for batch_index, batch in enumerate(inputs):
-			for name, input in batch:
-				if name == "targ":
-					training = True
-		"""
 		device = x.device
 		batch_size = len(x_list)