tweaked vall_e.emb.process to process audio one file at a time instead of all of the files for a given speaker, to avoid OOMing on systems with less memory (opt in with --low-memory)
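For reference, a minimal usage sketch of the new behavior (a sketch only: it assumes the package is importable as vall_e and uses only the process() parameters visible in the diff below; paths are placeholders):

from vall_e.emb.process import process

# default: queue every utterance for a speaker, then quantize the whole batch at once
process(input_audio="voices", device="cuda")

# low-memory: flush and quantize the queued jobs after each source file,
# so only one file's waveforms are held in memory at a time
process(input_audio="voices", device="cuda", low_memory=True)

# the same switch from the CLI, assuming main() is wired up as the usual module entry point:
#   python3 -m vall_e.emb.process --low-memory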

mrq 2024-08-06 14:24:40 -05:00
parent 9710b06b74
commit b6ba2cc8e7
2 changed files with 100 additions and 86 deletions

vall_e/emb/process.py

@@ -22,10 +22,66 @@ from .qnt import encode as quantize, _replace_file_extension
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
+def load_audio( path, device ):
+    waveform, sr = torchaudio.load( path )
+    if waveform.shape[0] > 1:
+        # mix channels
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    return waveform.to(device=device), sr
 def process_items( items, stride=0, stride_offset=0 ):
     items = sorted( items )
     return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
+def process_job( outpath, text, language, waveform, sample_rate ):
+    phones = phonemize(text, language=language)
+    qnt = quantize(waveform, sr=sample_rate, device=waveform.device)
+    if cfg.audio_backend == "dac":
+        np.save(open(outpath, "wb"), {
+            "codes": qnt.codes.cpu().numpy().astype(np.uint16),
+            "metadata": {
+                "original_length": qnt.original_length,
+                "sample_rate": qnt.sample_rate,
+                "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+                "chunk_length": qnt.chunk_length,
+                "channels": qnt.channels,
+                "padding": qnt.padding,
+                "dac_version": "1.0.0",
+                "text": text.strip(),
+                "phonemes": "".join(phones),
+                "language": language,
+            },
+        })
+    else:
+        np.save(open(outpath, "wb"), {
+            "codes": qnt.cpu().numpy().astype(np.uint16),
+            "metadata": {
+                "original_length": waveform.shape[-1],
+                "sample_rate": sample_rate,
+                "text": text.strip(),
+                "phonemes": "".join(phones),
+                "language": language,
+            },
+        })
+def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
+    if not jobs:
+        return
+    for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
+        outpath, text, language, waveform, sample_rate = job
+        try:
+            process_job( outpath, text, language, waveform, sample_rate )
+        except Exception as e:
+            print(f"Failed to quantize: {outpath}:", e)
+            if raise_exceptions:
+                raise e
+            continue
 def process(
     audio_backend="encodec",
     input_audio="voices",
@@ -36,6 +92,8 @@ def process(
     stride_offset=0,
     slice="auto",
+    low_memory=False,
     device="cuda",
     dtype="float16",
     amp=False,
@@ -73,6 +131,7 @@ def process(
     only_speakers = [] # only process these speakers
     always_slice_groups = [] # always slice from this group
+    audio_only = ["Noise"] # special pathway for processing audio only (without a transcription)
     missing = {
         "transcription": [],
@@ -102,19 +161,20 @@ def process(
             os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
-            if speaker_id == "Noise":
+            if speaker_id in audio_only:
                 for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
                     inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
                     outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}')
+                    outpath = _replace_file_extension(outpath, audio_extension)
-                    if _replace_file_extension(outpath, audio_extension).exists():
+                    if outpath.exists():
                         continue
-                    waveform, sample_rate = torchaudio.load(inpath)
+                    waveform, sample_rate = load_audio( inpath, device )
                     qnt = quantize(waveform, sr=sample_rate, device=device)
                     if cfg.audio_backend == "dac":
-                        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+                        np.save(open(outpath, "wb"), {
                             "codes": qnt.codes.cpu().numpy().astype(np.uint16),
                             "metadata": {
                                 "original_length": qnt.original_length,
@@ -128,7 +188,7 @@ def process(
                             },
                         })
                     else:
-                        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+                        np.save(open(outpath, "wb"), {
                             "codes": qnt.cpu().numpy().astype(np.uint16),
                             "metadata": {
                                 "original_length": waveform.shape[-1],
@@ -152,8 +212,7 @@ def process(
             if f'{group_name}/{speaker_id}' not in dataset:
                 dataset.append(f'{group_name}/{speaker_id}')
-            txts = []
-            wavs = []
+            jobs = []
             use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
@@ -171,26 +230,17 @@ def process(
                 if len(metadata[filename]["segments"]) == 0 or not use_slices:
                     outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
+                    outpath = _replace_file_extension(outpath, audio_extension)
                     text = metadata[filename]["text"]
-                    if len(text) == 0:
-                        continue
-                    if _replace_file_extension(outpath, audio_extension).exists():
+                    if len(text) == 0 or outpath.exists():
                         continue
+                    # audio not already loaded, load it
                     if waveform is None:
-                        waveform, sample_rate = torchaudio.load(inpath)
-                        if waveform.shape[0] > 1:
-                            waveform = torch.mean(waveform, dim=0, keepdim=True)
+                        waveform, sample_rate = load_audio( inpath, device )
-                    wavs.append((
-                        outpath,
-                        text,
-                        language,
-                        waveform,
-                        sample_rate
-                    ))
+                    jobs.append(( outpath, text, language, waveform, sample_rate ))
                 else:
                     i = 0
                     for segment in metadata[filename]["segments"]:
@@ -198,18 +248,15 @@ def process(
                         i = i + 1
                         outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+                        outpath = _replace_file_extension(outpath, audio_extension)
                         text = segment["text"]
-                        if len(text) == 0:
-                            continue
-                        if _replace_file_extension(outpath, audio_extension).exists():
+                        if len(text) == 0 or outpath.exists():
                             continue
+                        # audio not already loaded, load it
                        if waveform is None:
-                            waveform, sample_rate = torchaudio.load(inpath)
-                            if waveform.shape[0] > 1:
-                                waveform = torch.mean(waveform, dim=0, keepdim=True)
+                            waveform, sample_rate = load_audio( inpath, device )
                         start = int(segment['start'] * sample_rate)
                         end = int(segment['end'] * sample_rate)
@@ -222,59 +269,17 @@ def process(
                         if end - start < 0:
                             continue
-                        wavs.append((
-                            outpath,
-                            text,
-                            language,
-                            waveform[:, start:end],
-                            sample_rate
-                        ))
+                        jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate ))
+                # processes audio files one at a time
+                if low_memory:
+                    process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
+                    jobs = []
-            if len(wavs) > 0:
-                for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-                    try:
-                        outpath, text, language, waveform, sample_rate = job
-                        phones = phonemize(text, language=language)
-                        qnt = quantize(waveform, sr=sample_rate, device=device)
-                        if cfg.audio_backend == "dac":
-                            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                                "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                                "metadata": {
-                                    "original_length": qnt.original_length,
-                                    "sample_rate": qnt.sample_rate,
-                                    "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                                    "chunk_length": qnt.chunk_length,
-                                    "channels": qnt.channels,
-                                    "padding": qnt.padding,
-                                    "dac_version": "1.0.0",
-                                    "text": text.strip(),
-                                    "phonemes": "".join(phones),
-                                    "language": language,
-                                },
-                            })
-                        else:
-                            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                                "codes": qnt.cpu().numpy().astype(np.uint16),
-                                "metadata": {
-                                    "original_length": waveform.shape[-1],
-                                    "sample_rate": sample_rate,
-                                    "text": text.strip(),
-                                    "phonemes": "".join(phones),
-                                    "language": language,
-                                },
-                            })
-                    except Exception as e:
-                        print(f"Failed to quantize: {outpath}:", e)
-                        torchaudio.save( waveform.cpu )
-                        if raise_exceptions:
-                            raise e
-                        continue
+            # processes all audio files for a given speaker
+            if not low_memory:
+                process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
+                jobs = []
     open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
     open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@@ -287,6 +292,7 @@ def main():
     parser.add_argument("--input-metadata", type=str, default="training/metadata")
     parser.add_argument("--output-dataset", type=str, default="training/dataset")
     parser.add_argument("--raise-exceptions", action="store_true")
+    parser.add_argument("--low-memory", action="store_true")
     parser.add_argument("--stride", type=int, default=0)
    parser.add_argument("--stride-offset", type=int, default=0)
     parser.add_argument("--slice", type=str, default="auto")
@@ -314,6 +320,8 @@ def main():
         stride_offset=args.stride_offset,
         slice=args.slice,
+        low_memory=args.low_memory,
         device=args.device,
         dtype=args.dtype,
         amp=args.amp,
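Stripped of diff context, the batching change above reduces to flushing the job list per file instead of per speaker. A self-contained toy illustration of that pattern (not repo code; process_jobs_demo and the fake job tuples are stand-ins for the real process_jobs() and slicing logic):

# toy illustration: flush-per-file keeps the queue bounded to one file's worth of jobs
def process_jobs_demo(jobs, label=""):
    print(f"quantizing {len(jobs)} job(s) for {label}")

def run(files, low_memory=False):
    jobs = []
    for filename in files:
        jobs.append(("job-for", filename))  # stand-in for building (outpath, text, language, waveform, sample_rate)
        if low_memory:
            process_jobs_demo(jobs, label=filename)  # flush after every file
            jobs = []
    if not low_memory:
        process_jobs_demo(jobs, label="whole speaker")  # original behavior: one pass per speaker

run(["a.wav", "b.wav"], low_memory=True)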

vall_e/emb/qnt.py

@@ -254,10 +254,17 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
         dummy = True
     elif hasattr( metadata, "__dict__" ):
         metadata = metadata.__dict__
-        metadata.pop("codes")
     # generate object with copied metadata
-    artifact = DACFile( codes = codes, **metadata )
+    artifact = DACFile(
+        codes = codes,
+        chunk_length = metadata["chunk_length"],
+        original_length = metadata["original_length"],
+        input_db = metadata["input_db"],
+        channels = metadata["channels"],
+        sample_rate = metadata["sample_rate"],
+        padding = metadata["padding"],
+        dac_version = metadata["dac_version"],
+    )
     artifact.dummy = dummy
     # to-do: inject the sample rate encoded at, because we can actually decouple
@@ -362,9 +369,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
         levels = 8 if model.model_type == "24khz" else None
         with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-            # I guess it's safe to not encode in one chunk
-            #artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
-            artifact = model.compress(signal, verbose=False, n_quantizers=levels)
+            artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+            #artifact = model.compress(signal, n_quantizers=levels)
         return artifact.codes if not return_metadata else artifact
     # AudioDec uses a different pathway