tweaks and things

This commit is contained in:
mrq 2024-08-06 08:17:25 -05:00
parent 8bac8fe902
commit 9710b06b74
4 changed files with 55 additions and 24 deletions

View File

@ -119,8 +119,10 @@ def process(
continue continue
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}') inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}')
if not inpath.exists(): textpath = _replace_file_extension(inpath, ".original.txt")
if not inpath.exists() or not textpath.exists():
missing["audio"].append(str(inpath)) missing["audio"].append(str(inpath))
continue
extension = os.path.splitext(filename)[-1][1:] extension = os.path.splitext(filename)[-1][1:]
fname = filename.replace(f'.{extension}', "") fname = filename.replace(f'.{extension}', "")
@ -129,7 +131,7 @@ def process(
language = "en" language = "en"
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read() text = open(textpath, "r", encoding="utf-8").read()
if len(text) == 0: if len(text) == 0:
continue continue
@ -214,6 +216,13 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# do some assumption magic
# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
if args.device.isnumeric():
args.stride = torch.cuda.device_count()
args.stride_offset = int(args.device)
args.device = f'cuda:{args.device}'
process( process(
audio_backend=args.audio_backend, audio_backend=args.audio_backend,
input_audio=args.input_audio, input_audio=args.input_audio,

View File

@ -15,6 +15,10 @@ from pathlib import Path
from ..config import cfg from ..config import cfg
# need to validate if this is safe to import before modifying the config
from .g2p import encode as phonemize
from .qnt import encode as quantize, _replace_file_extension
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)
@ -58,11 +62,6 @@ def process(
cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False cfg.inference.amp = amp # False
# import after because we've overriden the config above
# need to validate if this is even necessary anymore
from .g2p import encode as phonemize
from .qnt import encode as quantize, _replace_file_extension
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
language_map = {} # k = group, v = language language_map = {} # k = group, v = language
@ -272,6 +271,7 @@ def process(
}) })
except Exception as e: except Exception as e:
print(f"Failed to quantize: {outpath}:", e) print(f"Failed to quantize: {outpath}:", e)
torchaudio.save( waveform.cpu )
if raise_exceptions: if raise_exceptions:
raise e raise e
continue continue
@ -283,19 +283,27 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--audio-backend", type=str, default="encodec") parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--input-audio", type=str, default="voices") parser.add_argument("--input-audio", type=str, default="voices")
parser.add_argument("--input-metadata", type=str, default="training/metadata") parser.add_argument("--input-metadata", type=str, default="training/metadata")
parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--output-dataset", type=str, default="training/dataset")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--raise-exceptions", action="store_true")
parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto") parser.add_argument("--slice", type=str, default="auto")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--amp", action="store_true")
args = parser.parse_args() args = parser.parse_args()
# do some assumption magic
# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
if args.device.isnumeric():
args.stride = torch.cuda.device_count()
args.stride_offset = int(args.device)
args.device = f'cuda:{args.device}'
process( process(
audio_backend=args.audio_backend, audio_backend=args.audio_backend,
input_audio=args.input_audio, input_audio=args.input_audio,

View File

@ -252,18 +252,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
dac_version='1.0.0', dac_version='1.0.0',
) )
dummy = True dummy = True
elif hasattr( metadata, "__dict__" ):
metadata = metadata.__dict__
metadata.pop("codes")
# generate object with copied metadata # generate object with copied metadata
artifact = DACFile( artifact = DACFile( codes = codes, **metadata )
codes = codes,
# yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
)
artifact.dummy = dummy artifact.dummy = dummy
# to-do: inject the sample rate encoded at, because we can actually decouple # to-do: inject the sample rate encoded at, because we can actually decouple
@ -368,7 +362,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
levels = 8 if model.model_type == "24khz" else None levels = 8 if model.model_type == "24khz" else None
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) # I guess it's safe to not encode in one chunk
#artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
artifact = model.compress(signal, verbose=False, n_quantizers=levels)
return artifact.codes if not return_metadata else artifact return artifact.codes if not return_metadata else artifact
# AudioDec uses a different pathway # AudioDec uses a different pathway

View File

@ -17,6 +17,10 @@ from pathlib import Path
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)
def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def transcribe( def transcribe(
input_audio = "voices", input_audio = "voices",
output_metadata = "training/metadata", output_metadata = "training/metadata",
@ -25,6 +29,9 @@ def transcribe(
skip_existing = True, skip_existing = True,
diarize = False, diarize = False,
stride = 0,
stride_offset = ,
batch_size = 16, batch_size = 16,
device = "cuda", device = "cuda",
dtype = "float16", dtype = "float16",
@ -42,7 +49,7 @@ def transcribe(
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
continue continue
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"): for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
continue continue
@ -55,7 +62,6 @@ def transcribe(
metadata = {} metadata = {}
for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"): for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"):
if skip_existing and filename in metadata: if skip_existing and filename in metadata:
continue continue
@ -122,6 +128,8 @@ def main():
parser.add_argument("--skip-existing", action="store_true") parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--diarize", action="store_true") parser.add_argument("--diarize", action="store_true")
parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--dtype", type=str, default="bfloat16")
@ -130,6 +138,13 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# do some assumption magic
# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
if args.device.isnumeric():
args.stride = torch.cuda.device_count()
args.stride_offset = int(args.device)
args.device = f'cuda:{args.device}'
transcribe( transcribe(
input_audio = args.input_audio, input_audio = args.input_audio,
output_metadata = args.output_metadata, output_metadata = args.output_metadata,
@ -138,6 +153,9 @@ def main():
skip_existing = args.skip_existing, skip_existing = args.skip_existing,
diarize = args.diarize, diarize = args.diarize,
stride = args.stride,
stride_offset = args.stride_offset,
batch_size = args.batch_size, batch_size = args.batch_size,
device = args.device, device = args.device,
dtype = args.dtype, dtype = args.dtype,