tweaks and things

parent 8bac8fe902
commit 9710b06b74
@@ -119,8 +119,10 @@ def process(
                 continue

             inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}')
-            if not inpath.exists():
+            textpath = _replace_file_extension(inpath, ".original.txt")
+
+            if not inpath.exists() or not textpath.exists():
                 missing["audio"].append(str(inpath))
                 continue

             extension = os.path.splitext(filename)[-1][1:]
             fname = filename.replace(f'.{extension}', "")
@@ -129,7 +131,7 @@ def process(
             language = "en"

             outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
-            text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+            text = open(textpath, "r", encoding="utf-8").read()

             if len(text) == 0:
                 continue
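The two hunks above resolve the transcript path once and gate on it together with the audio path, so a missing .original.txt gets logged under missing["audio"] instead of failing later at open(). _replace_file_extension itself isn't part of this diff; a rough sketch of the behavior implied by its usage (implementation and example paths here are assumptions, not the repo's actual code):

from pathlib import Path

def _replace_file_extension(path, suffix):
    # assumed: drop everything after the first dot in the filename, then append the new suffix
    return path.parent / (path.name.split(".")[0] + suffix)

inpath = Path('./voices/some_group/some_speaker/some_book/0001.flac')   # hypothetical layout
textpath = _replace_file_extension(inpath, ".original.txt")
print(textpath)   # voices/some_group/some_speaker/some_book/0001.original.txt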
@@ -214,6 +216,13 @@ def main():
     args = parser.parse_args()

+    # do some assumption magic
+    # to-do: find a nice way to spawn multiple processes where tqdm plays nicely
+    if args.device.isnumeric():
+        args.stride = torch.cuda.device_count()
+        args.stride_offset = int(args.device)
+        args.device = f'cuda:{args.device}'
+
     process(
         audio_backend=args.audio_backend,
         input_audio=args.input_audio,
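The added block reads a bare numeric --device as both a GPU pick and an implicit sharding request: with N visible GPUs, worker i takes every N-th work item starting at offset i. A small self-contained restatement of that logic (resolve_device is a made-up name for illustration; the real script just mutates args in place):

import torch

def resolve_device(device, stride=0, stride_offset=0):
    # bare index -> one worker per GPU, sharded by that index;
    # anything else (e.g. "cuda", "cpu") passes through with no sharding
    if device.isnumeric():
        stride = torch.cuda.device_count()
        stride_offset = int(device)
        device = f'cuda:{device}'
    return device, stride, stride_offset

print(resolve_device("2"))     # on a 4-GPU box: ('cuda:2', 4, 2) -> items 2, 6, 10, ...
print(resolve_device("cuda"))  # ('cuda', 0, 0) -> single worker, takes everything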
@@ -15,6 +15,10 @@ from pathlib import Path

 from ..config import cfg

+# need to validate if this is safe to import before modifying the config
+from .g2p import encode as phonemize
+from .qnt import encode as quantize, _replace_file_extension
+
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
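The comment's worry, in miniature: if .g2p / .qnt read cfg when they are imported, hoisting the imports to module scope means they see the defaults rather than the values process() sets later. A toy illustration of that failure mode (all names are stand-ins, not the real modules):

class Cfg:                       # stand-in for ..config.cfg
    weight_dtype = "float32"
cfg = Cfg()

snapshot_at_import = cfg.weight_dtype   # pretend .qnt captured this at import time

def process(dtype="bfloat16"):
    cfg.weight_dtype = dtype            # the override lands after the snapshot
    return snapshot_at_import, cfg.weight_dtype

print(process())   # ('float32', 'bfloat16') -> the import-time snapshot is stale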
@@ -58,11 +62,6 @@ def process(
     cfg.inference.weight_dtype = dtype # "bfloat16"
     cfg.inference.amp = amp # False

-    # import after because we've overriden the config above
-    # need to validate if this is even necessary anymore
-    from .g2p import encode as phonemize
-    from .qnt import encode as quantize, _replace_file_extension
-
     output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"

     language_map = {} # k = group, v = language
@@ -272,6 +271,7 @@ def process(
                 })
             except Exception as e:
                 print(f"Failed to quantize: {outpath}:", e)
+                torchaudio.save( waveform.cpu )
                 if raise_exceptions:
                     raise e
                 continue
@@ -283,19 +283,27 @@ def main():
     parser = argparse.ArgumentParser()

     parser.add_argument("--audio-backend", type=str, default="encodec")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--amp", action="store_true")
     parser.add_argument("--input-audio", type=str, default="voices")
     parser.add_argument("--input-metadata", type=str, default="training/metadata")
     parser.add_argument("--output-dataset", type=str, default="training/dataset")
-    parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--raise-exceptions", action="store_true")
     parser.add_argument("--stride", type=int, default=0)
     parser.add_argument("--stride-offset", type=int, default=0)
     parser.add_argument("--slice", type=str, default="auto")

+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--dtype", type=str, default="bfloat16")
+    parser.add_argument("--amp", action="store_true")
+
     args = parser.parse_args()

+    # do some assumption magic
+    # to-do: find a nice way to spawn multiple processes where tqdm plays nicely
+    if args.device.isnumeric():
+        args.stride = torch.cuda.device_count()
+        args.stride_offset = int(args.device)
+        args.device = f'cuda:{args.device}'
+
     process(
         audio_backend=args.audio_backend,
         input_audio=args.input_audio,
@@ -252,18 +252,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
                 dac_version='1.0.0',
             )
             dummy = True
+        elif hasattr( metadata, "__dict__" ):
+            metadata = metadata.__dict__
+            metadata.pop("codes")
+
         # generate object with copied metadata
-        artifact = DACFile(
-            codes = codes,
-            # yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
-            chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
-            original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
-            input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
-            channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
-            sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
-            padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
-            dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
-        )
+        artifact = DACFile( codes = codes, **metadata )
         artifact.dummy = dummy

         # to-do: inject the sample rate encoded at, because we can actually decouple
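The decode-side change swaps the field-by-field copy for normalize-then-splat: dicts pass straight through, objects get flattened via __dict__, and any stale codes entry is dropped so it can't shadow the freshly supplied codes. A standalone sketch of the same pattern, using a stand-in dataclass rather than the real DACFile:

from dataclasses import dataclass

@dataclass
class FakeDACFile:           # stand-in with the fields the removed lines copied
    codes: list
    chunk_length: int
    original_length: int
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

def rebuild(codes, metadata):
    if hasattr(metadata, "__dict__"):     # object -> plain dict
        metadata = dict(metadata.__dict__)
        metadata.pop("codes", None)       # don't let a stale field shadow the new codes
    return FakeDACFile(codes=codes, **metadata)

meta = dict(chunk_length=512, original_length=0, input_db=-12.0,
            channels=1, sample_rate=44_100, padding=True, dac_version='1.0.0')
print(rebuild(codes=[1, 2, 3], metadata=meta))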
@@ -368,7 +362,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
         levels = 8 if model.model_type == "24khz" else None

         with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-            artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+            # I guess it's safe to not encode in one chunk
+            #artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+            artifact = model.compress(signal, verbose=False, n_quantizers=levels)
         return artifact.codes if not return_metadata else artifact

     # AudioDec uses a different pathway
@@ -17,6 +17,10 @@ from pathlib import Path
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)

+def process_items( items, stride=0, stride_offset=0 ):
+    items = sorted( items )
+    return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
+
 def transcribe(
     input_audio = "voices",
     output_metadata = "training/metadata",
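process_items is the whole sharding story on the transcription side: sort for a deterministic order, then keep every stride-th item starting at stride_offset, so independent workers cover disjoint slices. A quick check with made-up speaker ids:

def process_items( items, stride=0, stride_offset=0 ):
    items = sorted( items )
    return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]

speakers = ["spk_03", "spk_00", "spk_02", "spk_01"]
print(process_items(speakers))                             # ['spk_00', 'spk_01', 'spk_02', 'spk_03']
print(process_items(speakers, stride=2, stride_offset=0))  # worker 0: ['spk_00', 'spk_02']
print(process_items(speakers, stride=2, stride_offset=1))  # worker 1: ['spk_01', 'spk_03']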
@@ -25,6 +29,9 @@ def transcribe(
     skip_existing = True,
     diarize = False,

+    stride = 0,
+    stride_offset = 0,
+
     batch_size = 16,
     device = "cuda",
     dtype = "float16",
@@ -42,7 +49,7 @@ def transcribe(
         if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
             continue

-        for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
+        for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
             if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
                 continue
@@ -55,7 +62,6 @@ def transcribe(
         metadata = {}

         for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"):
-
             if skip_existing and filename in metadata:
                 continue

@@ -122,6 +128,8 @@ def main():
     parser.add_argument("--skip-existing", action="store_true")
     parser.add_argument("--diarize", action="store_true")
     parser.add_argument("--batch-size", type=int, default=16)
+    parser.add_argument("--stride", type=int, default=0)
+    parser.add_argument("--stride-offset", type=int, default=0)

     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--dtype", type=str, default="bfloat16")
@@ -129,6 +137,13 @@ def main():
     # parser.add_argument("--raise-exceptions", action="store_true")

     args = parser.parse_args()

+    # do some assumption magic
+    # to-do: find a nice way to spawn multiple processes where tqdm plays nicely
+    if args.device.isnumeric():
+        args.stride = torch.cuda.device_count()
+        args.stride_offset = int(args.device)
+        args.device = f'cuda:{args.device}'
+
     transcribe(
         input_audio = args.input_audio,
@@ -138,6 +153,9 @@ def main():
         skip_existing = args.skip_existing,
         diarize = args.diarize,

+        stride = args.stride,
+        stride_offset = args.stride_offset,
+
         batch_size = args.batch_size,
         device = args.device,
         dtype = args.dtype,