From eac353cd0bd38b1f0e24dca9adbe1117b87e0588 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 6 Aug 2024 20:23:33 -0500 Subject: [PATCH] busy work and cleanup while I wait for 1TB of audio to quantize... again. --- README.md | 6 +- vall_e/config.py | 41 ++++++-- vall_e/data.py | 10 +- vall_e/emb/g2p.py | 28 +++-- vall_e/emb/process.py | 93 +++++------------ vall_e/emb/qnt.py | 230 ++++++++++++++++++++++++++++++++---------- 6 files changed, 269 insertions(+), 139 deletions(-) diff --git a/README.md b/README.md index 3497221..15ffdd7 100755 --- a/README.md +++ b/README.md @@ -58,13 +58,13 @@ If you already have a dataset you want, for example, your own large corpus or fo 1. Populate your source voices under `./voices/{group name}/{speaker name}/`. -2. Run `python3 ./scripts/transcribe_dataset.py`. This will generate a transcription with timestamps for your dataset. +2. Run `python3 -m vall_e.emb.transcribe`. This will generate a transcription with timestamps for your dataset. + If you're interested in using a different model, edit the script's `model_name` and `batch_size` variables. -3. Run `python3 ./scripts/process_dataset.py`. This will phonemize the transcriptions and quantize the audio. +3. Run `python3 -m vall_e.emb.process`. This will phonemize the transcriptions and quantize the audio. + If you're using a Descript-Audio-Codec based model, make sure to set the sample rate and audio backend accordingly. -4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset_list.json`. +4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset/list.json`. + Refer to `./vall_e/config.py` for additional configuration details.
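A quick sanity check for the output of step 3: each processed utterance is a pickled dict written with `np.save`. A minimal sketch (the path and `.enc` extension are illustrative; DAC-based backends write `.dac` instead):

```python
import numpy as np

# np.load reads the file's magic bytes, so the non-.npy extension is fine
artifact = np.load("./training/dataset/group/speaker/utterance.enc", allow_pickle=True)[()]

print(artifact["metadata"])     # text, phonemes, language, original_length, sample_rate, ...
print(artifact["codes"].shape)  # quantized audio codes, stored as uint16
```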
### Dataset Formats diff --git a/vall_e/config.py b/vall_e/config.py index 687da14..63306cb 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -157,6 +157,8 @@ class Dataset: max_resps: int = 1 # number of samples to target for training p_resp_append: float = 1.0 # probability to append another sample to the training target + p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window + sample_type: str = "path" # path | speaker sample_order: str = "interleaved" # duration sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable @@ -177,11 +179,8 @@ class Dataset: return self._frames_per_second if cfg.audio_backend == "dac": - # using the 44KHz model with 24KHz sources has a frame rate of 41Hz - if cfg.variable_sample_rate and cfg.sample_rate == 24_000: - return 41 - if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # to-do: find the actual value for 44.1K - return 86 + if cfg.sample_rate == 44_100: + return 87 if cfg.sample_rate == 16_000: return 50 @@ -712,14 +711,40 @@ class Config(BaseConfig): tokenizer_path: str = "./tokenizer.json" # tokenizer path sample_rate: int = 24_000 # sample rate the model expects - variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss - audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac" weights_format: str = "pth" # "pth" | "sft" - supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"]) + def set_audio_backend(self, audio_backend): + cfg.audio_backend = audio_backend + audio_extension = None + if audio_backend in ["encodec", "vocos"]: + audio_extension = ".enc" + cfg.sample_rate = 24_000 + cfg.model.resp_levels = 8 + elif audio_backend == "dac": + audio_extension = ".dac" + cfg.sample_rate = 44_100 + cfg.model.resp_levels = 9 + elif audio_backend == "audiodec": + audio_extension = ".dec" + cfg.sample_rate = 48_000 + cfg.model.resp_levels = 8 # ? + else: + raise Exception(f"Unknown audio backend: {audio_backend}")
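The helper keeps the backend, sample rate, and RVQ level count in sync, and pairs with the `audio_backend_extension` property added just below. A minimal usage sketch (assumes the patched `Config` above):

```python
from vall_e.config import cfg

cfg.set_audio_backend("dac")
assert cfg.sample_rate == 44_100
assert cfg.model.resp_levels == 9
assert cfg.audio_backend_extension == ".dac"
```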
+ + @property + def audio_backend_extension(self): + audio_extension = None + if self.audio_backend in ["encodec", "vocos"]: + audio_extension = ".enc" + elif self.audio_backend == "dac": + audio_extension = ".dac" + elif self.audio_backend == "audiodec": + audio_extension = ".dec" + return audio_extension + @property def model(self): for i, model in enumerate(self.models): diff --git a/vall_e/data.py b/vall_e/data.py index e20be63..8c7caf1 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -11,7 +11,7 @@ import torch import itertools from .config import cfg -from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt +from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler from .utils.distributed import global_rank, local_rank, world_size from .utils.io import torch_save, torch_load @@ -368,7 +368,9 @@ def get_phone_symmap(): return cfg.tokenizer.get_vocab() def tokenize( phones ): - return cfg.tokenizer.encode( "".join(phones) ) + if isinstance( phones, list ): + phones = "".join( phones ) + return cfg.tokenizer.encode( phones ) def get_lang_symmap(): return { @@ -1146,6 +1148,10 @@ class Dataset(_Dataset): if text is None: text = torch.tensor([bos_id, eos_id]).to(self.text_dtype) + # pad the target with silence + if random.random() < cfg.dataset.p_resp_pad_silence: + resps = pad_codes_with_silence( resps ) + return dict( index=index, path=Path(path), diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py index 16e25de..96d5f15 100755 --- a/vall_e/emb/g2p.py +++ b/vall_e/emb/g2p.py @@ -54,17 +54,29 @@ def encode(text: str, language="en-us", backend="auto", punctuation=True, stress if not backend or backend == "auto": backend = "espeak" # if language[:2] != "en" else "festival" - text = [ text ] - backend = _get_backend(language=language, backend=backend, stress=stress, strip=strip, punctuation=punctuation) if backend is not None: - tokens = backend.phonemize( text, strip=strip ) + tokens = backend.phonemize( [ text ], strip=strip ) else: - tokens = phonemize( text, language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress ) + tokens = phonemize( [ text ], language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress ) if not len(tokens): - tokens = [] - else: - tokens = list(tokens[0]) + raise Exception(f"Failed to phonemize, phonemizer returned an empty result for: {text}") - return tokens \ No newline at end of file + return tokens[0] + +# Helper function to debug phonemizer +if __name__ == "__main__": + import argparse # may not be imported at module scope + + parser = argparse.ArgumentParser() + + parser.add_argument("string", type=str) + parser.add_argument("--language", type=str, default="en-us") + parser.add_argument("--backend", type=str, default="auto") + parser.add_argument("--no-punctuation", action="store_true") + parser.add_argument("--no-stress", action="store_true") + parser.add_argument("--no-strip", action="store_true") + + args = parser.parse_args() + + phonemes = encode( args.string, language=args.language, backend=args.backend, punctuation=not args.no_punctuation, stress=not args.no_stress, strip=not args.no_strip ) + print( phonemes ) \ No newline at end of file diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 5d32379..8418620 100644
--- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -17,7 +17,7 @@ from ..config import cfg # need to validate if this is safe to import before modifying the config from .g2p import encode as phonemize -from .qnt import encode as quantize, _replace_file_extension +from .qnt import encode as quantize def pad(num, zeroes): return str(num).zfill(zeroes+1) @@ -33,12 +33,11 @@ def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] -def process_job( outpath, text, language, waveform, sample_rate ): - phones = phonemize(text, language=language) +def process_job( outpath, waveform, sample_rate, text=None, language="en" ): qnt = quantize(waveform, sr=sample_rate, device=waveform.device) if cfg.audio_backend == "dac": - np.save(open(outpath, "wb"), { + state_dict = { "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, @@ -49,33 +48,35 @@ def process_job( outpath, text, language, waveform, sample_rate ): "channels": qnt.channels, "padding": qnt.padding, "dac_version": "1.0.0", - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, }, - }) + } else: - np.save(open(outpath, "wb"), { + state_dict = { "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate, - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, }, - }) + } + + if text: + text = text.strip() + state_dict['metadata'] |= { + "text": text, + "phonemes": phonemize(text, language=language), + "language": language, + } + + np.save(open(outpath, "wb"), state_dict) def process_jobs( jobs, speaker_id="", raise_exceptions=True ): if not jobs: return for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"): - outpath, text, language, waveform, sample_rate = job + outpath, waveform, sample_rate, text, language = job try: - process_job( outpath, text, language, waveform, sample_rate ) + process_job( outpath, waveform, sample_rate, text, language ) except Exception as e: print(f"Failed to quantize: {outpath}:", e) if raise_exceptions: @@ -98,30 +99,16 @@ def process( dtype="float16", amp=False, ): - # encodec / vocos - - if audio_backend in ["encodec", "vocos"]: - audio_extension = ".enc" - cfg.sample_rate = 24_000 - cfg.model.resp_levels = 8 - elif audio_backend == "dac": - audio_extension = ".dac" - cfg.sample_rate = 44_100 - cfg.model.resp_levels = 9 - elif cfg.audio_backend == "audiodec": - sample_rate = 48_000 - audio_extension = ".dec" - cfg.model.resp_levels = 8 # ? 
- else: - raise Exception(f"Unknown audio backend: {audio_backend}") - # prepare from args - cfg.audio_backend = audio_backend # "encodec" + cfg.set_audio_backend(audio_backend) + audio_extension = cfg.audio_backend_extension + cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.amp = amp # False output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" + # to-do: make this also prepared from args language_map = {} # k = group, v = language ignore_groups = [] # skip these groups @@ -164,8 +151,7 @@ def process( if speaker_id in audio_only: for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')): inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') - outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}') - outpath = _replace_file_extension(outpath, audio_extension) + outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}').with_suffix(audio_extension) if outpath.exists(): continue @@ -173,28 +159,7 @@ def process( waveform, sample_rate = load_audio( inpath, device ) qnt = quantize(waveform, sr=sample_rate, device=device) - if cfg.audio_backend == "dac": - np.save(open(outpath, "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - }, - }) - else: - np.save(open(outpath, "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": waveform.shape[-1], - "sample_rate": sample_rate, - }, - }) + process_job(outpath, waveform, sample_rate) continue @@ -229,8 +194,7 @@ def process( language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en") if len(metadata[filename]["segments"]) == 0 or not use_slices: - outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') - outpath = _replace_file_extension(outpath, audio_extension) + outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}').with_suffix(audio_extension) text = metadata[filename]["text"] if len(text) == 0 or outpath.exists(): @@ -240,15 +204,14 @@ def process( if waveform is None: waveform, sample_rate = load_audio( inpath, device ) - jobs.append(( outpath, text, language, waveform, sample_rate )) + jobs.append(( outpath, waveform, sample_rate, text, language )) else: i = 0 for segment in metadata[filename]["segments"]: id = pad(i, 4) i = i + 1 - outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}') - outpath = _replace_file_extension(outpath, audio_extension) + outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension) text = segment["text"] if len(text) == 0 or outpath.exists(): @@ -269,7 +232,7 @@ def process( if end - start < 0: continue - jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate )) + jobs.append(( outpath, waveform[:, start:end], sample_rate, text, language )) # processes audio files one at a time if low_memory: diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 9e8e6e9..89432fd 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -2,8 +2,10 @@ from ..config
import cfg import argparse import random +import math import torch import torchaudio +import numpy as np from functools import cache from pathlib import Path @@ -215,9 +217,12 @@ except Exception as e: """ @cache -def _load_encodec_model(device="cuda", levels=cfg.model.max_levels): +def _load_encodec_model(device="cuda", levels=0): assert cfg.sample_rate == 24_000 + if not levels: + levels = cfg.model.max_levels + # too lazy to un-if ladder this shit bandwidth_id = 6.0 if levels == 2: @@ -243,9 +248,12 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels): return model @cache -def _load_vocos_model(device="cuda", levels=cfg.model.max_levels): +def _load_vocos_model(device="cuda", levels=0): assert cfg.sample_rate == 24_000 + if not levels: + levels = cfg.model.max_levels + model = Vocos.from_pretrained("charactr/vocos-encodec-24khz") model = model.to(device) model = model.eval() @@ -267,32 +275,27 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels): return model @cache -def _load_dac_model(device="cuda", levels=cfg.model.max_levels): +def _load_dac_model(device="cuda"): kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest") - if not cfg.variable_sample_rate: - # yes there's a better way, something like f'{cfg.sample.rate//1000}hz' - if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # because I messed up and had assumed it was an even 44K and not 44.1K - kwargs["model_type"] = "44khz" - elif cfg.sample_rate == 16_000: - kwargs["model_type"] = "16khz" - else: - raise Exception(f'unsupported sample rate: {cfg.sample_rate}') + # yes there's a better way, something like f'{cfg.sample.rate//1000}hz' + if cfg.sample_rate == 44_100: + kwargs["model_type"] = "44khz" + elif cfg.sample_rate == 16_000: + kwargs["model_type"] = "16khz" + else: + raise Exception(f'unsupported sample rate: {cfg.sample_rate}') model = __load_dac_model(**kwargs) model = model.to(device) model = model.eval() - # to revisit later, but experiments shown that this is a bad idea - if cfg.variable_sample_rate: - model.sample_rate = cfg.sample_rate - model.backend = "dac" model.model_type = kwargs["model_type"] return model @cache -def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels): +def _load_audiodec_model(device="cuda", model_name=None): if not model_name: model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1" sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name) @@ -307,25 +310,25 @@ def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_le return model @cache -def _load_model(device="cuda", backend=None, levels=cfg.model.max_levels): +def _load_model(device="cuda", backend=None): if not backend: backend = cfg.audio_backend if backend == "audiodec": - return _load_audiodec_model(device, levels=levels) + return _load_audiodec_model(device) if backend == "dac": - return _load_dac_model(device, levels=levels) + return _load_dac_model(device) if backend == "vocos": - return _load_vocos_model(device, levels=levels) + return _load_vocos_model(device) - return _load_encodec_model(device, levels=levels) + return _load_encodec_model(device) def unload_model(): _load_model.cache_clear() _load_encodec_model.cache_clear() # because vocos can only decode @torch.inference_mode() -def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None): +def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): # upcast so it won't whine if 
codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8: codes = codes.to(torch.int32) @@ -342,7 +345,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}' # load the model - model = _load_model(device, levels=levels) + model = _load_model(device) # AudioDec uses a different pathway if model.backend == "audiodec": @@ -356,7 +359,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N dummy = False if metadata is None: metadata = dict( - chunk_length= codes.shape[-1], + chunk_length=codes.shape[-1], original_length=0, input_db=-12, channels=1, @@ -367,10 +370,11 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N dummy = True elif hasattr( metadata, "__dict__" ): metadata = metadata.__dict__ + # generate object with copied metadata artifact = DACFile( codes = codes, - chunk_length = metadata["chunk_length"], + chunk_length = math.floor(window_duration * cfg.dataset.frames_per_second) if window_duration else metadata["chunk_length"], original_length = metadata["original_length"], input_db = metadata["input_db"], channels = metadata["channels"], @@ -400,8 +404,8 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N return wav, model.sample_rate # huh -def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels): - return decode(resps, device=device, levels=levels) +def decode_to_wave(resps: Tensor, device="cuda"): + return decode(resps, device=device) def decode_to_file(resps: Tensor, path: Path, device="cuda"): wavs, sr = decode(resps, device=device) @@ -471,23 +475,19 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device= raise Exception(f'Currently only DAC is supported') @torch.inference_mode() -def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True): - +def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None): # DAC uses a different pathway if cfg.audio_backend == "dac": - model = _load_dac_model(device, levels=levels ) + model = _load_dac_model( device ) signal = AudioSignal(wav, sample_rate=sr) - if not isinstance(levels, int): - levels = 8 if model.model_type == "24khz" else None - - artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) - #artifact = model.compress(signal, n_quantizers=levels) + artifact = model.compress(signal, win_duration=window_duration, verbose=False) # , n_quantizers=levels) + #artifact = model.compress(signal) return artifact.codes if not return_metadata else artifact # AudioDec uses a different pathway if cfg.audio_backend == "audiodec": - model = _load_audiodec_model(device, levels=levels ) + model = _load_audiodec_model(device) wav = wav.unsqueeze(0) wav = convert_audio(wav, sr, model.sample_rate, 1) wav = wav.to(device) @@ -498,7 +498,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod return quantized # vocos does not encode wavs to encodecs, so just use normal encodec - model = _load_encodec_model(device, levels=levels) + model = _load_encodec_model(device) wav = wav.unsqueeze(0) wav = convert_audio(wav, sr, model.sample_rate, model.channels) wav = wav.to(device) @@ -544,6 +544,9 @@ def encode_from_file(path, device="cuda"): """ Helper Functions """ + +# DAC "silence": [ 568, 804, 10, 674, 364, 981, 568, 
378, 731] + # trims from the start, up to `target` def trim( qnt, target, reencode=False, device="cuda" ): length = max( qnt.shape[0], qnt.shape[1] ) @@ -613,23 +616,23 @@ def interleave_audio( *args, audio=None ): return res # concats two audios together -def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ): +def concat_audio( *args, reencode=False, device="cuda" ): qnts = [ *args ] qnts = [ qnt for qnt in qnts if qnt is not None ] # just naively combine the codes if not reencode: return torch.concat( qnts ) - decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ] + decoded = [ decode(qnt, device=device)[0] for qnt in qnts ] combined = torch.concat( decoded ) - return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t() + return encode(combined, cfg.sample_rate, device=device)[0].t() # merges two quantized audios together # requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic -def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ): +def merge_audio( *args, device="cuda", scale=[] ): qnts = [ *args ] qnts = [ qnt for qnt in qnts if qnt is not None ] - decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ] + decoded = [ decode(qnt, device=device)[0] for qnt in qnts ] # max length max_length = max([ wav.shape[-1] for wav in decoded ]) @@ -646,17 +649,138 @@ def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_le decoded[i] = decoded[i] * scale[i] combined = sum(decoded) / len(decoded) - return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t() + return encode(combined, cfg.sample_rate, device=device)[0].t() -""" +# Get framerate for a given audio backend +def get_framerate( backend=None, sample_rate=None ): + if not backend: + backend = cfg.audio_backend + if not sample_rate: + sample_rate = cfg.sample_rate + + if backend == "dac": + if sample_rate == 44_100: + return 87 + if sample_rate == 16_000: + return 50 + + # 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz + return 75 + +# Generates quantized silence +def get_silence( length, device=None, codes=None ): + length = math.floor(length * get_framerate()) + if codes is None: + if cfg.audio_backend == "dac": + codes = [ 568, 804, 10, 674, 364, 981, 568, 378, 731 ] + else: + codes = [ 62, 424, 786, 673, 622, 986, 570, 948 ] + + return torch.tensor([ codes for _ in range( length ) ], device=device, dtype=torch.int16) + +# Pads a sequence of codes with silence +def pad_codes_with_silence( codes, size=1 ): + duration = codes.shape[0] / get_framerate() + difference = math.ceil( duration + size ) - duration + + silence = get_silence( difference, device=codes.device ) + + half = math.floor(difference / 2 * get_framerate()) + + return torch.concat( [ silence[half:, :], codes, silence[:half, :] ], dim=0 ) + +# Generates an empty waveform +def get_silent_waveform( length, device=None ): + length = math.floor(length * cfg.sample_rate) + return torch.zeros( 1, length, device=device, dtype=torch.float32 ) + +# Pads a waveform with silence +def pad_waveform_with_silence( waveform, sample_rate, size=1 ): + duration = waveform.shape[-1] / sample_rate + difference = math.ceil( duration + size ) - duration + + silence = get_silent_waveform( difference, device=waveform.device ) + + half = math.floor(difference / 2 * sample_rate) + + return torch.concat( [ silence[:, half:], waveform, silence[:, :half] ], dim=-1 )
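Worked numbers for the new silence helpers, assuming the default EnCodec/Vocos backend at 75 frames per second (a sketch; the values follow directly from the definitions above):

```python
codes = get_silence(1.3)                # floor(1.3 * 75) = 97 frames -> shape (97, 8)
padded = pad_codes_with_silence(codes)  # 97/75 s is padded up to ceil(97/75 + 1) = 3 s
print(codes.shape, padded.shape)        # torch.Size([97, 8]) torch.Size([225, 8]); 225/75 = 3 s
```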
+ +# Encodes/decodes audio, and helps me debug things if __name__ == "__main__": - cfg.sample_rate = 48_000 - cfg.audio_backend = "audiodec" + parser = argparse.ArgumentParser() - wav, sr = torchaudio.load("in.wav") - codes = encode( wav, sr ).t() # for some reason - print( "ENCODED:", codes.shape, codes ) - wav, sr = decode( codes ) - print( "DECODED:", wav.shape, wav ) - torchaudio.save("out.wav", wav.cpu(), sr) -""" \ No newline at end of file + parser.add_argument("--audio-backend", type=str, default="encodec") + parser.add_argument("--input", type=Path) + parser.add_argument("--output", type=Path, default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--dtype", type=str, default="float16") + parser.add_argument("--window-duration", type=float, default=None) # for DAC, the window duration for encoding / decoding + parser.add_argument("--print", action="store_true") # prints codes and metadata + parser.add_argument("--pad", action="store_true") # to test if padding with silence modifies the waveform / quants too much + + args = parser.parse_args() + + # prepare from args + cfg.set_audio_backend(args.audio_backend) + audio_extension = cfg.audio_backend_extension + + cfg.inference.weight_dtype = args.dtype # "bfloat16" + cfg.inference.amp = args.dtype != "float32" + cfg.device = args.device + + # decode + if args.input.suffix == audio_extension: + args.output = args.input.with_suffix('.wav') if not args.output else args.output.with_suffix('.wav') + + artifact = np.load(args.input, allow_pickle=True)[()] + codes = torch.from_numpy(artifact['codes'])[0][:, :].t().to(device=cfg.device, dtype=torch.int16) + + # pad to nearest + if args.pad: + codes = pad_codes_with_silence( codes ) + del artifact['metadata'] + + waveform, sample_rate = decode( codes, device=cfg.device, metadata=artifact['metadata'] if 'metadata' in artifact else None, window_duration=args.window_duration ) + + torchaudio.save(args.output, waveform.cpu(), sample_rate) + + # print + if args.print: + torch.set_printoptions(profile="full") + + print( "Metadata:", artifact['metadata'] ) + print( "Codes:", codes.shape, codes ) + # encode + else: + args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension) + + waveform, sample_rate = torchaudio.load(args.input) + + # pad to nearest + if args.pad: + waveform = pad_waveform_with_silence( waveform, sample_rate ) + + qnt = encode(waveform.to(cfg.device), sr=sample_rate, device=cfg.device, window_duration=args.window_duration) + + if cfg.audio_backend == "dac": + state_dict = { + "codes": qnt.codes.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": qnt.original_length, + "sample_rate": qnt.sample_rate, + + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), + "chunk_length": qnt.chunk_length, + "channels": qnt.channels, + "padding": qnt.padding, + "dac_version": "1.0.0", + }, + } + else: + state_dict = { + "codes": qnt.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": waveform.shape[-1], + "sample_rate": sample_rate, + }, + } + np.save(open(args.output, "wb"), state_dict) \ No newline at end of file
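For reference, the new `__main__` block makes the codec round-trip testable from the shell; based on the argument parser above, something like the following should encode `in.wav` to `in.enc` under the default EnCodec backend, then decode it back and print the stored codes and metadata (file names are illustrative):

```
python3 -m vall_e.emb.qnt --input in.wav --audio-backend encodec
python3 -m vall_e.emb.qnt --input in.enc --audio-backend encodec --print
```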