tweaked vall_e.emb.process to process audio one file at a time instead of all files for a given speaker at once, to avoid OOMing on memory-constrained systems; gated behind --low-memory

mrq 2024-08-06 14:24:40 -05:00
parent 9710b06b74
commit b6ba2cc8e7
2 changed files with 100 additions and 86 deletions
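
In short: transcription/quantization work is now collected into job tuples and flushed through process_jobs() either after every file (--low-memory) or once per speaker. A minimal sketch of how the reworked entry point might be driven, assuming the import path implied by the commit message and using only flags and parameters visible in the diff below:

	# CLI (module path assumed from the commit message):
	#   python3 -m vall_e.emb.process --low-memory --device cuda --dtype float16
	from vall_e.emb.process import process  # import path assumed

	process(
		audio_backend="encodec",
		input_audio="voices",
		low_memory=True,    # flush jobs after each file instead of once per speaker
		device="cuda",
		dtype="float16",
		amp=False,
	)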

vall_e/emb/process.py

@@ -22,10 +22,66 @@ from .qnt import encode as quantize, _replace_file_extension
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
+def load_audio( path, device ):
+	waveform, sr = torchaudio.load( path )
+	if waveform.shape[0] > 1:
+		# mix channels
+		waveform = torch.mean(waveform, dim=0, keepdim=True)
+	return waveform.to(device=device), sr
+
+def process_items( items, stride=0, stride_offset=0 ):
+	items = sorted( items )
+	return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
+
+def process_job( outpath, text, language, waveform, sample_rate ):
+	phones = phonemize(text, language=language)
+	qnt = quantize(waveform, sr=sample_rate, device=waveform.device)
+
+	if cfg.audio_backend == "dac":
+		np.save(open(outpath, "wb"), {
+			"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+			"metadata": {
+				"original_length": qnt.original_length,
+				"sample_rate": qnt.sample_rate,
+				"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+				"chunk_length": qnt.chunk_length,
+				"channels": qnt.channels,
+				"padding": qnt.padding,
+				"dac_version": "1.0.0",
+				"text": text.strip(),
+				"phonemes": "".join(phones),
+				"language": language,
+			},
+		})
+	else:
+		np.save(open(outpath, "wb"), {
+			"codes": qnt.cpu().numpy().astype(np.uint16),
+			"metadata": {
+				"original_length": waveform.shape[-1],
+				"sample_rate": sample_rate,
+				"text": text.strip(),
+				"phonemes": "".join(phones),
+				"language": language,
+			},
+		})
+
+def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
+	if not jobs:
+		return
+
+	for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
+		outpath, text, language, waveform, sample_rate = job
+		try:
+			process_job( outpath, text, language, waveform, sample_rate )
+		except Exception as e:
+			print(f"Failed to quantize: {outpath}:", e)
+			if raise_exceptions:
+				raise e
+			continue
+
 def process(
 	audio_backend="encodec",
 	input_audio="voices",
@@ -36,6 +92,8 @@ def process(
 	stride_offset=0,
 	slice="auto",
+	low_memory=False,
 
 	device="cuda",
 	dtype="float16",
 	amp=False,
@@ -73,6 +131,7 @@ def process(
 	only_speakers = [] # only process these speakers
 	always_slice_groups = [] # always slice from this group
+	audio_only = ["Noise"] # special pathway for processing audio only (without a transcription)
 
 	missing = {
 		"transcription": [],
@@ -102,19 +161,20 @@ def process(
 			os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
 
-			if speaker_id == "Noise":
+			if speaker_id in audio_only:
 				for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
 					inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
 					outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}')
+					outpath = _replace_file_extension(outpath, audio_extension)
 
-					if _replace_file_extension(outpath, audio_extension).exists():
+					if outpath.exists():
 						continue
 
-					waveform, sample_rate = torchaudio.load(inpath)
+					waveform, sample_rate = load_audio( inpath, device )
 					qnt = quantize(waveform, sr=sample_rate, device=device)
 
 					if cfg.audio_backend == "dac":
-						np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+						np.save(open(outpath, "wb"), {
 							"codes": qnt.codes.cpu().numpy().astype(np.uint16),
 							"metadata": {
 								"original_length": qnt.original_length,
@@ -128,7 +188,7 @@ def process(
 							},
 						})
 					else:
-						np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+						np.save(open(outpath, "wb"), {
 							"codes": qnt.cpu().numpy().astype(np.uint16),
 							"metadata": {
 								"original_length": waveform.shape[-1],
@@ -152,8 +212,7 @@ def process(
 			if f'{group_name}/{speaker_id}' not in dataset:
 				dataset.append(f'{group_name}/{speaker_id}')
 
-			txts = []
-			wavs = []
+			jobs = []
 
 			use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
@@ -171,26 +230,17 @@ def process(
 				if len(metadata[filename]["segments"]) == 0 or not use_slices:
 					outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
+					outpath = _replace_file_extension(outpath, audio_extension)
 					text = metadata[filename]["text"]
 
-					if len(text) == 0:
-						continue
-
-					if _replace_file_extension(outpath, audio_extension).exists():
+					if len(text) == 0 or outpath.exists():
 						continue
 
 					# audio not already loaded, load it
 					if waveform is None:
-						waveform, sample_rate = torchaudio.load(inpath)
-						if waveform.shape[0] > 1:
-							waveform = torch.mean(waveform, dim=0, keepdim=True)
+						waveform, sample_rate = load_audio( inpath, device )
 
-					wavs.append((
-						outpath,
-						text,
-						language,
-						waveform,
-						sample_rate
-					))
+					jobs.append(( outpath, text, language, waveform, sample_rate ))
 				else:
 					i = 0
 					for segment in metadata[filename]["segments"]:
@@ -198,18 +248,15 @@ def process(
 						i = i + 1
 						outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+						outpath = _replace_file_extension(outpath, audio_extension)
 						text = segment["text"]
 
-						if len(text) == 0:
-							continue
-
-						if _replace_file_extension(outpath, audio_extension).exists():
+						if len(text) == 0 or outpath.exists():
 							continue
 
 						# audio not already loaded, load it
 						if waveform is None:
-							waveform, sample_rate = torchaudio.load(inpath)
-							if waveform.shape[0] > 1:
-								waveform = torch.mean(waveform, dim=0, keepdim=True)
+							waveform, sample_rate = load_audio( inpath, device )
 
 						start = int(segment['start'] * sample_rate)
 						end = int(segment['end'] * sample_rate)
@@ -222,59 +269,17 @@ def process(
 						if end - start < 0:
 							continue
 
-						wavs.append((
-							outpath,
-							text,
-							language,
-							waveform[:, start:end],
-							sample_rate
-						))
+						jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate ))
 
-			if len(wavs) > 0:
-				for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-					try:
-						outpath, text, language, waveform, sample_rate = job
+				# processes audio files one at a time
+				if low_memory:
+					process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
+					jobs = []
 
-						phones = phonemize(text, language=language)
-						qnt = quantize(waveform, sr=sample_rate, device=device)
-
-						if cfg.audio_backend == "dac":
-							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-								"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-								"metadata": {
-									"original_length": qnt.original_length,
-									"sample_rate": qnt.sample_rate,
-									"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-									"chunk_length": qnt.chunk_length,
-									"channels": qnt.channels,
-									"padding": qnt.padding,
-									"dac_version": "1.0.0",
-									"text": text.strip(),
-									"phonemes": "".join(phones),
-									"language": language,
-								},
-							})
-						else:
-							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-								"codes": qnt.cpu().numpy().astype(np.uint16),
-								"metadata": {
-									"original_length": waveform.shape[-1],
-									"sample_rate": sample_rate,
-									"text": text.strip(),
-									"phonemes": "".join(phones),
-									"language": language,
-								},
-							})
-					except Exception as e:
-						print(f"Failed to quantize: {outpath}:", e)
-						torchaudio.save( waveform.cpu )
-						if raise_exceptions:
-							raise e
-						continue
+			# processes all audio files for a given speaker
+			if not low_memory:
+				process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
+				jobs = []
 
 	open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
 	open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@@ -287,6 +292,7 @@ def main():
 	parser.add_argument("--input-metadata", type=str, default="training/metadata")
 	parser.add_argument("--output-dataset", type=str, default="training/dataset")
 	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--low-memory", action="store_true")
 	parser.add_argument("--stride", type=int, default=0)
 	parser.add_argument("--stride-offset", type=int, default=0)
 	parser.add_argument("--slice", type=str, default="auto")
@@ -314,6 +320,8 @@ def main():
 		stride_offset=args.stride_offset,
 		slice=args.slice,
+		low_memory=args.low_memory,
 
 		device=args.device,
 		dtype=args.dtype,
 		amp=args.amp,
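
For reference, process_items() (added above) shards a sorted worklist across parallel runs via --stride/--stride-offset; a quick worked example using only the function as shown in the diff:

	from vall_e.emb.process import process_items  # import path assumed

	items = ["a.wav", "b.wav", "c.wav", "d.wav", "e.wav"]

	# stride=0 disables sharding: every item is returned, sorted
	assert process_items(items) == ["a.wav", "b.wav", "c.wav", "d.wav", "e.wav"]

	# stride=2 splits the work across two runs; stride_offset picks the shard
	assert process_items(items, stride=2, stride_offset=0) == ["a.wav", "c.wav", "e.wav"]
	assert process_items(items, stride=2, stride_offset=1) == ["b.wav", "d.wav"]

The memory win from --low-memory comes purely from flush cadence: each job tuple holds a live waveform tensor already moved to the target device by load_audio(), so flushing after each file keeps at most one file's slices resident instead of an entire speaker's.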

vall_e/emb/qnt.py

@@ -254,10 +254,17 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None
 			dummy = True
 	elif hasattr( metadata, "__dict__" ):
 		metadata = metadata.__dict__
-		metadata.pop("codes")
-	# generate object with copied metadata
-	artifact = DACFile( codes = codes, **metadata )
+	artifact = DACFile(
+		codes = codes,
+		chunk_length = metadata["chunk_length"],
+		original_length = metadata["original_length"],
+		input_db = metadata["input_db"],
+		channels = metadata["channels"],
+		sample_rate = metadata["sample_rate"],
+		padding = metadata["padding"],
+		dac_version = metadata["dac_version"],
+	)
 	artifact.dummy = dummy
 
 	# to-do: inject the sample rate encoded at, because we can actually decouple
@@ -362,9 +369,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels
 		levels = 8 if model.model_type == "24khz" else None
 
 	with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-		# I guess it's safe to not encode in one chunk
-		#artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
-		artifact = model.compress(signal, verbose=False, n_quantizers=levels)
+		artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+		#artifact = model.compress(signal, n_quantizers=levels)
 
 	return artifact.codes if not return_metadata else artifact
 
 # AudioDec uses a different pathway
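
The two qnt.py changes pair with the new save format: decode() now copies only the DACFile fields out of the stored metadata dict (presumably because the metadata written by process_job also carries "text"/"phonemes"/"language", which DACFile(**metadata) would not accept), and encode() passes win_duration=None, which, as I read DAC's compress(), encodes the whole signal as a single window rather than fixed-size chunks. A rough round-trip sketch; the artifact filename and decode()'s return value are assumptions:

	import numpy as np
	import torch
	from vall_e.emb.qnt import decode  # import path assumed

	# process_job saves a pickled dict via np.save on an open file handle,
	# so load it back with allow_pickle and unwrap the 0-d object array
	state = np.load("out.dac", allow_pickle=True)[()]  # filename assumed
	codes = torch.tensor(state["codes"].astype(np.int64))

	# decode() rebuilds the DACFile from the stored metadata; the extra
	# transcription keys in the dict are simply ignored
	wav = decode(codes, metadata=state["metadata"])  # return value assumed to be audio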