tweaked vall_e.emb.process to process audio one file at a time, instead of all of the files for a given speaker, to avoid OOMing on memory-constrained systems when --low-memory is passed
parent 9710b06b74
commit b6ba2cc8e7
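Usage note: the one-file-at-a-time path is opt-in through the new --low-memory flag. Assuming the module is invoked with python -m (it only exposes an argparse-driven main()), a run on a memory-constrained machine would look roughly like:

    python3 -m vall_e.emb.process --input-audio="voices" --input-metadata="training/metadata" --output-dataset="training/dataset" --device="cuda" --low-memory

The paths and device above are just the defaults visible in the argument parser further down; only --low-memory is new in this commit.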
vall_e/emb/process.py
@@ -22,10 +22,66 @@ from .qnt import encode as quantize, _replace_file_extension
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
 
+def load_audio( path, device ):
+    waveform, sr = torchaudio.load( path )
+    if waveform.shape[0] > 1:
+        # mix channels
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    return waveform.to(device=device), sr
+
 def process_items( items, stride=0, stride_offset=0 ):
     items = sorted( items )
     return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
 
+def process_job( outpath, text, language, waveform, sample_rate ):
+    phones = phonemize(text, language=language)
+    qnt = quantize(waveform, sr=sample_rate, device=waveform.device)
+
+    if cfg.audio_backend == "dac":
+        np.save(open(outpath, "wb"), {
+            "codes": qnt.codes.cpu().numpy().astype(np.uint16),
+            "metadata": {
+                "original_length": qnt.original_length,
+                "sample_rate": qnt.sample_rate,
+
+                "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+                "chunk_length": qnt.chunk_length,
+                "channels": qnt.channels,
+                "padding": qnt.padding,
+                "dac_version": "1.0.0",
+
+                "text": text.strip(),
+                "phonemes": "".join(phones),
+                "language": language,
+            },
+        })
+    else:
+        np.save(open(outpath, "wb"), {
+            "codes": qnt.cpu().numpy().astype(np.uint16),
+            "metadata": {
+                "original_length": waveform.shape[-1],
+                "sample_rate": sample_rate,
+
+                "text": text.strip(),
+                "phonemes": "".join(phones),
+                "language": language,
+            },
+        })
+
+def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
+    if not jobs:
+        return
+
+    for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
+        outpath, text, language, waveform, sample_rate = job
+        try:
+            process_job( outpath, text, language, waveform, sample_rate )
+        except Exception as e:
+            print(f"Failed to quantize: {outpath}:", e)
+            if raise_exceptions:
+                raise e
+            continue
+
 def process(
     audio_backend="encodec",
     input_audio="voices",
@@ -36,6 +92,8 @@ def process(
     stride_offset=0,
     slice="auto",
 
+    low_memory=False,
+
     device="cuda",
     dtype="float16",
     amp=False,
@@ -73,6 +131,7 @@ def process(
     only_speakers = [] # only process these speakers
 
     always_slice_groups = [] # always slice from this group
+    audio_only = ["Noise"] # special pathway for processing audio only (without a transcription)
 
     missing = {
         "transcription": [],
@@ -102,19 +161,20 @@ def process(
 
             os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
 
-            if speaker_id == "Noise":
+            if speaker_id in audio_only:
                 for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
                     inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
                     outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}')
+                    outpath = _replace_file_extension(outpath, audio_extension)
 
-                    if _replace_file_extension(outpath, audio_extension).exists():
+                    if outpath.exists():
                         continue
 
-                    waveform, sample_rate = torchaudio.load(inpath)
+                    waveform, sample_rate = load_audio( inpath, device )
                     qnt = quantize(waveform, sr=sample_rate, device=device)
 
                     if cfg.audio_backend == "dac":
-                        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+                        np.save(open(outpath, "wb"), {
                             "codes": qnt.codes.cpu().numpy().astype(np.uint16),
                             "metadata": {
                                 "original_length": qnt.original_length,
@@ -128,7 +188,7 @@ def process(
                             },
                         })
                     else:
-                        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+                        np.save(open(outpath, "wb"), {
                             "codes": qnt.cpu().numpy().astype(np.uint16),
                             "metadata": {
                                 "original_length": waveform.shape[-1],
@@ -152,8 +212,7 @@ def process(
             if f'{group_name}/{speaker_id}' not in dataset:
                 dataset.append(f'{group_name}/{speaker_id}')
 
-            txts = []
-            wavs = []
+            jobs = []
 
             use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
@@ -171,26 +230,17 @@ def process(
 
                 if len(metadata[filename]["segments"]) == 0 or not use_slices:
                     outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
+                    outpath = _replace_file_extension(outpath, audio_extension)
                     text = metadata[filename]["text"]
 
-                    if len(text) == 0:
-                        continue
-
-                    if _replace_file_extension(outpath, audio_extension).exists():
+                    if len(text) == 0 or outpath.exists():
                         continue
 
+                    # audio not already loaded, load it
                     if waveform is None:
-                        waveform, sample_rate = torchaudio.load(inpath)
-                        if waveform.shape[0] > 1:
-                            waveform = torch.mean(waveform, dim=0, keepdim=True)
+                        waveform, sample_rate = load_audio( inpath, device )
 
-                    wavs.append((
-                        outpath,
-                        text,
-                        language,
-                        waveform,
-                        sample_rate
-                    ))
+                    jobs.append(( outpath, text, language, waveform, sample_rate ))
                 else:
                     i = 0
                     for segment in metadata[filename]["segments"]:
@@ -198,18 +248,15 @@ def process(
                         i = i + 1
 
                         outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+                        outpath = _replace_file_extension(outpath, audio_extension)
                         text = segment["text"]
 
-                        if len(text) == 0:
-                            continue
-
-                        if _replace_file_extension(outpath, audio_extension).exists():
+                        if len(text) == 0 or outpath.exists():
                             continue
 
+                        # audio not already loaded, load it
                         if waveform is None:
-                            waveform, sample_rate = torchaudio.load(inpath)
-                            if waveform.shape[0] > 1:
-                                waveform = torch.mean(waveform, dim=0, keepdim=True)
+                            waveform, sample_rate = load_audio( inpath, device )
 
                         start = int(segment['start'] * sample_rate)
                         end = int(segment['end'] * sample_rate)
@@ -222,59 +269,17 @@ def process(
                         if end - start < 0:
                             continue
 
-                        wavs.append((
-                            outpath,
-                            text,
-                            language,
-                            waveform[:, start:end],
-                            sample_rate
-                        ))
-
-            if len(wavs) > 0:
-                for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-                    try:
-                        outpath, text, language, waveform, sample_rate = job
-
-                        phones = phonemize(text, language=language)
-                        qnt = quantize(waveform, sr=sample_rate, device=device)
-
-                        if cfg.audio_backend == "dac":
-                            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                                "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                                "metadata": {
-                                    "original_length": qnt.original_length,
-                                    "sample_rate": qnt.sample_rate,
-
-                                    "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                                    "chunk_length": qnt.chunk_length,
-                                    "channels": qnt.channels,
-                                    "padding": qnt.padding,
-                                    "dac_version": "1.0.0",
-
-                                    "text": text.strip(),
-                                    "phonemes": "".join(phones),
-                                    "language": language,
-                                },
-                            })
-                        else:
-                            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                                "codes": qnt.cpu().numpy().astype(np.uint16),
-                                "metadata": {
-                                    "original_length": waveform.shape[-1],
-                                    "sample_rate": sample_rate,
-
-                                    "text": text.strip(),
-                                    "phonemes": "".join(phones),
-                                    "language": language,
-                                },
-                            })
-                    except Exception as e:
-                        print(f"Failed to quantize: {outpath}:", e)
-                        torchaudio.save( waveform.cpu )
-                        if raise_exceptions:
-                            raise e
-                        continue
+                        jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate ))
+
+                # processes audio files one at a time
+                if low_memory:
+                    process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
+                    jobs = []
+
+            # processes all audio files for a given speaker
+            if not low_memory:
+                process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
+                jobs = []
 
     open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
     open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
@@ -287,6 +292,7 @@ def main():
     parser.add_argument("--input-metadata", type=str, default="training/metadata")
     parser.add_argument("--output-dataset", type=str, default="training/dataset")
    parser.add_argument("--raise-exceptions", action="store_true")
+    parser.add_argument("--low-memory", action="store_true")
     parser.add_argument("--stride", type=int, default=0)
     parser.add_argument("--stride-offset", type=int, default=0)
     parser.add_argument("--slice", type=str, default="auto")
@@ -314,6 +320,8 @@ def main():
         stride_offset=args.stride_offset,
         slice=args.slice,
 
+        low_memory=args.low_memory,
+
         device=args.device,
         dtype=args.dtype,
         amp=args.amp,
vall_e/emb/qnt.py
@@ -254,10 +254,17 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
             dummy = True
         elif hasattr( metadata, "__dict__" ):
             metadata = metadata.__dict__
-            metadata.pop("codes")
 
         # generate object with copied metadata
-        artifact = DACFile( codes = codes, **metadata )
+        artifact = DACFile(
+            codes = codes,
+            chunk_length = metadata["chunk_length"],
+            original_length = metadata["original_length"],
+            input_db = metadata["input_db"],
+            channels = metadata["channels"],
+            sample_rate = metadata["sample_rate"],
+            padding = metadata["padding"],
+            dac_version = metadata["dac_version"],
+        )
         artifact.dummy = dummy
 
         # to-do: inject the sample rate encoded at, because we can actually decouple
|
||||||
levels = 8 if model.model_type == "24khz" else None
|
levels = 8 if model.model_type == "24khz" else None
|
||||||
|
|
||||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
# I guess it's safe to not encode in one chunk
|
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
||||||
#artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
#artifact = model.compress(signal, n_quantizers=levels)
|
||||||
artifact = model.compress(signal, verbose=False, n_quantizers=levels)
|
|
||||||
return artifact.codes if not return_metadata else artifact
|
return artifact.codes if not return_metadata else artifact
|
||||||
|
|
||||||
# AudioDec uses a different pathway
|
# AudioDec uses a different pathway
|
||||||
|
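In short, process() now queues (outpath, text, language, waveform, sample_rate) jobs and flushes them either after every file (with --low-memory) or once per speaker (the old behaviour, still the default). A minimal, self-contained sketch of that batching pattern, using stand-in names (run, the toy speakers dict, the print-based process_jobs) rather than the real pipeline:

    def process_jobs(jobs, speaker_id=""):
        # stand-in for the real process_jobs(): quantize everything currently queued
        for outpath, waveform in jobs:
            print(f"quantizing {outpath} for {speaker_id} ({len(waveform)} samples)")

    def run(speakers, low_memory=False):
        for speaker_id, files in speakers.items():
            jobs = []
            for filename, waveform in files.items():
                # each job carries an already-loaded waveform, so the queue is what holds memory
                jobs.append((filename, waveform))
                if low_memory:
                    # flush after every file: at most one file's audio sits in the queue
                    process_jobs(jobs, speaker_id=f"{speaker_id}/{filename}")
                    jobs = []
            if not low_memory:
                # default: flush once per speaker, as before this commit
                process_jobs(jobs, speaker_id=speaker_id)

    run({"speaker_a": {"a1.wav": [0.0] * 16000, "a2.wav": [0.0] * 8000}}, low_memory=True)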