updated process_datasets.py, added argparsing so I can mostly stop manually editing things, and some other cleanup

mrq 2024-08-05 15:59:25 -05:00
parent debcc93e7e
commit 7cdfa3dc0c
2 changed files with 233 additions and 180 deletions
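For reference, a hypothetical invocation of the updated script with the new flags (the flag names are taken from the argparse definitions in the diff below; the script path and the chosen values are assumptions, not something recorded in this commit):

	python3 process_datasets.py --audio-backend encodec --input-audio voices --input-metadata metadata --output_group training --device cuda --slice auto
	# process every other speaker and re-raise quantization errors instead of skipping them
	python3 process_datasets.py --audio-backend dac --stride 2 --raise-exceptions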

View File

@@ -1,191 +1,104 @@
 import os
 import json
+import argparse
 import torch
 import torchaudio
 import numpy as np
 from tqdm.auto import tqdm
 from pathlib import Path
 from vall_e.config import cfg
-# things that could be args
-cfg.sample_rate = 24_000
-cfg.audio_backend = "encodec"
-"""
-cfg.inference.weight_dtype = "bfloat16"
-cfg.inference.dtype = torch.bfloat16
-cfg.inference.amp = True
-"""
-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
-input_audio = "voices"
-input_metadata = "metadata"
-output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
-device = "cuda"
-audio_extension = ".enc"
-if cfg.audio_backend == "dac":
-	audio_extension = ".dac"
-elif cfg.audio_backend == "audiodec":
-	audio_extension = ".dec"
-slice = "auto"
-missing = {
-	"transcription": [],
-	"audio": []
-}
-dataset = []
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
-for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
-	if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
-		print("Is not dir:", f'./{input_audio}/{dataset_name}/')
-		continue
-	for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
-		if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
-			print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
-			continue
-		os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
-		if speaker_id == "Noise":
-			for filename in sorted(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/')):
-				inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{filename}')
-				if _replace_file_extension(outpath, audio_extension).exists():
-					continue
-				waveform, sample_rate = torchaudio.load(inpath)
-				qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-				if cfg.audio_backend == "dac":
-					np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-						"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-						"metadata": {
-							"original_length": qnt.original_length,
-							"sample_rate": qnt.sample_rate,
-							"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-							"chunk_length": qnt.chunk_length,
-							"channels": qnt.channels,
-							"padding": qnt.padding,
-							"dac_version": "1.0.0",
-						},
-					})
-				else:
-					np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-						"codes": qnt.cpu().numpy().astype(np.uint16),
-						"metadata": {
-							"original_length": waveform.shape[-1],
-							"sample_rate": sample_rate,
-						},
-					})
-			continue
-		metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
-		if not metadata_path.exists():
-			missing["transcription"].append(str(metadata_path))
-			continue
-		try:
-			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
-		except Exception as e:
-			missing["transcription"].append(str(metadata_path))
-			continue
-		if f'{dataset_name}/{speaker_id}' not in dataset:
-			dataset.append(f'{dataset_name}/{speaker_id}')
-		txts = []
-		wavs = []
-		use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
-		for filename in sorted(metadata.keys()):
-			inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
-			if not inpath.exists():
-				missing["audio"].append(str(inpath))
-				continue
-			extension = os.path.splitext(filename)[-1][1:]
-			fname = filename.replace(f'.{extension}', "")
-			waveform, sample_rate = None, None
-			language = metadata[filename]["language"] if "language" in metadata[filename] else "en"
-			if len(metadata[filename]["segments"]) == 0 or not use_slices:
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
-				text = metadata[filename]["text"]
-				if len(text) == 0:
-					continue
-				if _replace_file_extension(outpath, audio_extension).exists():
-					continue
-				if waveform is None:
-					waveform, sample_rate = torchaudio.load(inpath)
-					if waveform.shape[0] > 1:
-						waveform = torch.mean(waveform, dim=0, keepdim=True)
-				wavs.append((
-					outpath,
-					text,
-					language,
-					waveform,
-					sample_rate
-				))
-			else:
-				i = 0
-				for segment in metadata[filename]["segments"]:
-					id = pad(i, 4)
-					i = i + 1
-					outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
-					text = segment["text"]
-					if len(text) == 0:
-						continue
+def process_items( items, stride=0 ):
+	items = sorted( items )
+	return items if stride == 0 else [ item for i, item in enumerate( items ) if i % stride == 0 ]
+def process_dataset( args ):
+	# encodec / vocos
+	if args.audio_backend in ["encodec", "vocos"]:
+		audio_extension = ".enc"
+		cfg.sample_rate = 24_000
+		cfg.model.resp_levels = 8
+	elif args.audio_backend == "dac":
+		audio_extension = ".dac"
+		cfg.sample_rate = 44_100
+		cfg.model.resp_levels = 9
+	elif cfg.audio_backend == "audiodec":
+		sample_rate = 48_000
+		audio_extension = ".dec"
+		cfg.model.resp_levels = 8 # ?
+	else:
+		raise Exception(f"Unknown audio backend: {args.audio_backend}")
+	# prepare from args
+	cfg.audio_backend = args.audio_backend # "encodec"
+	cfg.inference.weight_dtype = args.dtype # "bfloat16"
+	cfg.inference.amp = args.amp # False
+	# import after because we've overriden the config above
+	from vall_e.emb.g2p import encode as valle_phonemize
+	from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
+	input_audio = args.input_audio # "voice""
+	input_metadata = args.input_metadata # "metadata"
+	output_group = f"{args.output_group}-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
+	device = args.device # "cuda"
+	raise_exceptions = args.raise_exceptions # False
+	stride = args.stride # 0
+	slice = args.slice # "auto"
+	language_map = {} # k = group, v = language
+	ignore_groups = [] # skip these groups
+	ignore_speakers = [] # skip these speakers
+	only_groups = [] # only process these groups
+	only_speakers = [] # only process these speakers
+	always_slice_groups = [] # always slice from this group
+	missing = {
+		"transcription": [],
+		"audio": []
+	}
+	dataset = []
+	for group_name in sorted(os.listdir(f'./{input_audio}/')):
+		if not os.path.isdir(f'./{input_audio}/{group_name}/'):
+			print("Is not dir:", f'./{input_audio}/{group_name}/')
+			continue
+		if group_name in ignore_groups:
+			continue
+		if only_groups and group_name not in only_groups:
+			continue
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride), desc=f"Processing speaker in {group_name}"):
+			if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
+				print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
+				continue
+			if speaker_id in ignore_speakers:
+				continue
+			if only_speakers and speaker_id not in only_speakers:
+				continue
+			os.makedirs(f'./{output_group}/{group_name}/{speaker_id}/', exist_ok=True)
+			if speaker_id == "Noise":
+				for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
+					inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+					outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{filename}')
 					if _replace_file_extension(outpath, audio_extension).exists():
 						continue
-					if waveform is None:
-						waveform, sample_rate = torchaudio.load(inpath)
-						if waveform.shape[0] > 1:
-							waveform = torch.mean(waveform, dim=0, keepdim=True)
-					start = int(segment['start'] * sample_rate)
-					end = int(segment['end'] * sample_rate)
-					if start < 0:
-						start = 0
-					if end >= waveform.shape[-1]:
-						end = waveform.shape[-1] - 1
-					if end - start < 0:
-						continue
-					wavs.append((
-						outpath,
-						text,
-						language,
-						waveform[:, start:end],
-						sample_rate
-					))
-		if len(wavs) > 0:
-			for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-				try:
-					outpath, text, language, waveform, sample_rate = job
-					phones = valle_phonemize( text, language=language )
+					waveform, sample_rate = torchaudio.load(inpath)
 					qnt = valle_quantize(waveform, sr=sample_rate, device=device)
 					if cfg.audio_backend == "dac":
@@ -200,10 +113,6 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 							"channels": qnt.channels,
 							"padding": qnt.padding,
 							"dac_version": "1.0.0",
-							"text": text.strip(),
-							"phonemes": "".join(phones),
-							"language": language,
 						},
 					})
 				else:
@@ -212,15 +121,168 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 						"metadata": {
 							"original_length": waveform.shape[-1],
 							"sample_rate": sample_rate,
-							"text": text.strip(),
-							"phonemes": "".join(phones),
-							"language": language,
 						},
 					})
-				except Exception as e:
-					print(f"Failed to quantize: {outpath}:", e)
-					continue
-open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
-open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset))
+				continue
+			metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
+			if not metadata_path.exists():
+				missing["transcription"].append(str(metadata_path))
+				continue
+			try:
+				metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+			except Exception as e:
+				missing["transcription"].append(str(metadata_path))
+				continue
+			if f'{group_name}/{speaker_id}' not in dataset:
+				dataset.append(f'{group_name}/{speaker_id}')
+			txts = []
+			wavs = []
+			use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
+			for filename in sorted(metadata.keys()):
+				inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+				if not inpath.exists():
+					missing["audio"].append(str(inpath))
+					continue
+				extension = os.path.splitext(filename)[-1][1:]
+				fname = filename.replace(f'.{extension}', "")
+				waveform, sample_rate = None, None
+				language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en")
+				if len(metadata[filename]["segments"]) == 0 or not use_slices:
+					outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{fname}.{extension}')
+					text = metadata[filename]["text"]
+					if len(text) == 0:
+						continue
+					if _replace_file_extension(outpath, audio_extension).exists():
+						continue
+					if waveform is None:
+						waveform, sample_rate = torchaudio.load(inpath)
+						if waveform.shape[0] > 1:
+							waveform = torch.mean(waveform, dim=0, keepdim=True)
+					wavs.append((
+						outpath,
+						text,
+						language,
+						waveform,
+						sample_rate
+					))
+				else:
+					i = 0
+					for segment in metadata[filename]["segments"]:
+						id = pad(i, 4)
+						i = i + 1
+						outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+						text = segment["text"]
+						if len(text) == 0:
+							continue
+						if _replace_file_extension(outpath, audio_extension).exists():
+							continue
+						if waveform is None:
+							waveform, sample_rate = torchaudio.load(inpath)
+							if waveform.shape[0] > 1:
+								waveform = torch.mean(waveform, dim=0, keepdim=True)
+						start = int(segment['start'] * sample_rate)
+						end = int(segment['end'] * sample_rate)
+						if start < 0:
+							start = 0
+						if end >= waveform.shape[-1]:
+							end = waveform.shape[-1] - 1
+						if end - start < 0:
+							continue
+						wavs.append((
+							outpath,
+							text,
+							language,
+							waveform[:, start:end],
+							sample_rate
+						))
+			if len(wavs) > 0:
+				for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+					try:
+						outpath, text, language, waveform, sample_rate = job
+						phones = valle_phonemize(text, language=language)
+						qnt = valle_quantize(waveform, sr=sample_rate, device=device)
+						if cfg.audio_backend == "dac":
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": qnt.original_length,
+									"sample_rate": qnt.sample_rate,
+									"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+									"chunk_length": qnt.chunk_length,
+									"channels": qnt.channels,
+									"padding": qnt.padding,
+									"dac_version": "1.0.0",
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+						else:
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": waveform.shape[-1],
+									"sample_rate": sample_rate,
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+					except Exception as e:
+						print(f"Failed to quantize: {outpath}:", e)
+						if raise_exceptions:
+							raise e
+						continue
+	open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
+	open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset))
+def main():
+	parser = argparse.ArgumentParser()
+	parser.add_argument("--audio-backend", type=str, default="encodec")
+	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--input-audio", type=str, default="voices")
+	parser.add_argument("--input-metadata", type=str, default="metadata")
+	parser.add_argument("--output_group", type=str, default="training")
+	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--stride", type=int, default=0)
+	parser.add_argument("--slice", type=str, default="auto")
+	args = parser.parse_args()
+	process_dataset( args )
+if __name__ == "__main__":
+	main()

View File

@@ -813,7 +813,7 @@ class Base(nn.Module):
 				inputs_embeds=x,
 				past_key_values=state,
 				position_ids=position_ids,
-				use_cache=True,
+				use_cache=not self.training,
 				# return_dict=True,
 			)
 			if self.n_experts > 1 and self.training:
@@ -1350,15 +1350,6 @@ class Base(nn.Module):
 		x, m = list_to_tensor(x_list)
 		training = self.training
-		# yes, there's a better way.
-		"""
-		training = False
-		for batch_index, batch in enumerate(inputs):
-			for name, input in batch:
-				if name == "targ":
-					training = True
-		"""
 		device = x.device
 		batch_size = len(x_list)
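The use_cache change in the first hunk above turns the attention key/value cache off while the model is training; the cache only pays off for incremental decoding at inference time, and building it during training just costs memory. A minimal sketch of the pattern, assuming a Hugging Face-style decoder call like the one in the diff (the attribute and variable names here are illustrative):

	output = self.model(
		inputs_embeds=x,
		past_key_values=state,
		use_cache=not self.training,  # only keep the KV cache when decoding incrementally at inference
	)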