tweaked vall_e.emb.process to instead process audio one file at a time instead of all the files for a given speaker to avoid OOMing on less-memory-filled systems with --low-memory
This commit is contained in:
parent
9710b06b74
commit
b6ba2cc8e7
|
@ -22,10 +22,66 @@ from .qnt import encode as quantize, _replace_file_extension
|
|||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
||||
def load_audio( path, device ):
|
||||
waveform, sr = torchaudio.load( path )
|
||||
if waveform.shape[0] > 1:
|
||||
# mix channels
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
return waveform.to(device=device), sr
|
||||
|
||||
def process_items( items, stride=0, stride_offset=0 ):
|
||||
items = sorted( items )
|
||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||
|
||||
def process_job( outpath, text, language, waveform, sample_rate ):
|
||||
phones = phonemize(text, language=language)
|
||||
qnt = quantize(waveform, sr=sample_rate, device=waveform.device)
|
||||
|
||||
if cfg.audio_backend == "dac":
|
||||
np.save(open(outpath, "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
"dac_version": "1.0.0",
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
else:
|
||||
np.save(open(outpath, "wb"), {
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
"sample_rate": sample_rate,
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
|
||||
def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
|
||||
if not jobs:
|
||||
return
|
||||
|
||||
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
||||
outpath, text, language, waveform, sample_rate = job
|
||||
try:
|
||||
process_job( outpath, text, language, waveform, sample_rate )
|
||||
except Exception as e:
|
||||
print(f"Failed to quantize: {outpath}:", e)
|
||||
if raise_exceptions:
|
||||
raise e
|
||||
continue
|
||||
|
||||
def process(
|
||||
audio_backend="encodec",
|
||||
input_audio="voices",
|
||||
|
@ -36,6 +92,8 @@ def process(
|
|||
stride_offset=0,
|
||||
slice="auto",
|
||||
|
||||
low_memory=False,
|
||||
|
||||
device="cuda",
|
||||
dtype="float16",
|
||||
amp=False,
|
||||
|
@ -73,6 +131,7 @@ def process(
|
|||
only_speakers = [] # only process these speakers
|
||||
|
||||
always_slice_groups = [] # always slice from this group
|
||||
audio_only = ["Noise"] # special pathway for processing audio only (without a transcription)
|
||||
|
||||
missing = {
|
||||
"transcription": [],
|
||||
|
@ -102,19 +161,20 @@ def process(
|
|||
|
||||
os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
|
||||
|
||||
if speaker_id == "Noise":
|
||||
if speaker_id in audio_only:
|
||||
for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
|
||||
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
|
||||
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}')
|
||||
outpath = _replace_file_extension(outpath, audio_extension)
|
||||
|
||||
if _replace_file_extension(outpath, audio_extension).exists():
|
||||
if outpath.exists():
|
||||
continue
|
||||
|
||||
waveform, sample_rate = torchaudio.load(inpath)
|
||||
waveform, sample_rate = load_audio( inpath, device )
|
||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
if cfg.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
np.save(open(outpath, "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
|
@ -128,7 +188,7 @@ def process(
|
|||
},
|
||||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
np.save(open(outpath, "wb"), {
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
|
@ -152,8 +212,7 @@ def process(
|
|||
if f'{group_name}/{speaker_id}' not in dataset:
|
||||
dataset.append(f'{group_name}/{speaker_id}')
|
||||
|
||||
txts = []
|
||||
wavs = []
|
||||
jobs = []
|
||||
|
||||
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
|
||||
|
||||
|
@ -171,26 +230,17 @@ def process(
|
|||
|
||||
if len(metadata[filename]["segments"]) == 0 or not use_slices:
|
||||
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
|
||||
outpath = _replace_file_extension(outpath, audio_extension)
|
||||
text = metadata[filename]["text"]
|
||||
|
||||
if len(text) == 0:
|
||||
continue
|
||||
|
||||
if _replace_file_extension(outpath, audio_extension).exists():
|
||||
if len(text) == 0 or outpath.exists():
|
||||
continue
|
||||
|
||||
# audio not already loaded, load it
|
||||
if waveform is None:
|
||||
waveform, sample_rate = torchaudio.load(inpath)
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
waveform, sample_rate = load_audio( inpath, device )
|
||||
|
||||
wavs.append((
|
||||
outpath,
|
||||
text,
|
||||
language,
|
||||
waveform,
|
||||
sample_rate
|
||||
))
|
||||
jobs.append(( outpath, text, language, waveform, sample_rate ))
|
||||
else:
|
||||
i = 0
|
||||
for segment in metadata[filename]["segments"]:
|
||||
|
@ -198,18 +248,15 @@ def process(
|
|||
i = i + 1
|
||||
|
||||
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
|
||||
outpath = _replace_file_extension(outpath, audio_extension)
|
||||
text = segment["text"]
|
||||
|
||||
if len(text) == 0:
|
||||
continue
|
||||
|
||||
if _replace_file_extension(outpath, audio_extension).exists():
|
||||
if len(text) == 0 or outpath.exists():
|
||||
continue
|
||||
|
||||
# audio not already loaded, load it
|
||||
if waveform is None:
|
||||
waveform, sample_rate = torchaudio.load(inpath)
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
waveform, sample_rate = load_audio( inpath, device )
|
||||
|
||||
start = int(segment['start'] * sample_rate)
|
||||
end = int(segment['end'] * sample_rate)
|
||||
|
@ -222,59 +269,17 @@ def process(
|
|||
if end - start < 0:
|
||||
continue
|
||||
|
||||
wavs.append((
|
||||
outpath,
|
||||
text,
|
||||
language,
|
||||
waveform[:, start:end],
|
||||
sample_rate
|
||||
))
|
||||
jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate ))
|
||||
|
||||
if len(wavs) > 0:
|
||||
for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
|
||||
try:
|
||||
outpath, text, language, waveform, sample_rate = job
|
||||
# processes audio files one at a time
|
||||
if low_memory:
|
||||
process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
|
||||
jobs = []
|
||||
|
||||
phones = phonemize(text, language=language)
|
||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
|
||||
if cfg.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
"dac_version": "1.0.0",
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
"sample_rate": sample_rate,
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Failed to quantize: {outpath}:", e)
|
||||
torchaudio.save( waveform.cpu )
|
||||
if raise_exceptions:
|
||||
raise e
|
||||
continue
|
||||
# processes all audio files for a given speaker
|
||||
if not low_memory:
|
||||
process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
|
||||
jobs = []
|
||||
|
||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||
open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
|
||||
|
@ -287,6 +292,7 @@ def main():
|
|||
parser.add_argument("--input-metadata", type=str, default="training/metadata")
|
||||
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
||||
parser.add_argument("--raise-exceptions", action="store_true")
|
||||
parser.add_argument("--low-memory", action="store_true")
|
||||
parser.add_argument("--stride", type=int, default=0)
|
||||
parser.add_argument("--stride-offset", type=int, default=0)
|
||||
parser.add_argument("--slice", type=str, default="auto")
|
||||
|
@ -314,6 +320,8 @@ def main():
|
|||
stride_offset=args.stride_offset,
|
||||
slice=args.slice,
|
||||
|
||||
low_memory=args.low_memory,
|
||||
|
||||
device=args.device,
|
||||
dtype=args.dtype,
|
||||
amp=args.amp,
|
||||
|
|
|
@ -254,10 +254,17 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
|||
dummy = True
|
||||
elif hasattr( metadata, "__dict__" ):
|
||||
metadata = metadata.__dict__
|
||||
metadata.pop("codes")
|
||||
|
||||
# generate object with copied metadata
|
||||
artifact = DACFile( codes = codes, **metadata )
|
||||
artifact = DACFile(
|
||||
codes = codes,
|
||||
chunk_length = metadata["chunk_length"],
|
||||
original_length = metadata["original_length"],
|
||||
input_db = metadata["input_db"],
|
||||
channels = metadata["channels"],
|
||||
sample_rate = metadata["sample_rate"],
|
||||
padding = metadata["padding"],
|
||||
dac_version = metadata["dac_version"],
|
||||
)
|
||||
artifact.dummy = dummy
|
||||
|
||||
# to-do: inject the sample rate encoded at, because we can actually decouple
|
||||
|
@ -362,9 +369,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
|
|||
levels = 8 if model.model_type == "24khz" else None
|
||||
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
# I guess it's safe to not encode in one chunk
|
||||
#artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
||||
artifact = model.compress(signal, verbose=False, n_quantizers=levels)
|
||||
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
||||
#artifact = model.compress(signal, n_quantizers=levels)
|
||||
return artifact.codes if not return_metadata else artifact
|
||||
|
||||
# AudioDec uses a different pathway
|
||||
|
|
Loading…
Reference in New Issue
Block a user