busy work and cleanup while I wait for 1TB of audio to quantize... again.

This commit is contained in:
mrq 2024-08-06 20:23:33 -05:00
parent f284c7ea9c
commit eac353cd0b
6 changed files with 269 additions and 139 deletions

View File

@@ -58,13 +58,13 @@ If you already have a dataset you want, for example, your own large corpus or fo
1. Populate your source voices under `./voices/{group name}/{speaker name}/`.
-2. Run `python3 ./scripts/transcribe_dataset.py`. This will generate a transcription with timestamps for your dataset.
+2. Run `python3 -m vall_e.emb.transcribe`. This will generate a transcription with timestamps for your dataset.
    + If you're interested in using a different model, edit the script's `model_name` and `batch_size` variables.
-3. Run `python3 ./scripts/process_dataset.py`. This will phonemize the transcriptions and quantize the audio.
+3. Run `python3 -m vall_e.emb.process`. This will phonemize the transcriptions and quantize the audio.
    + If you're using a Descript-Audio-Codec based model, ensure to set the sample rate and audio backend accordingly.
-4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset_list.json`.
+4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset/list.json`.
    + Refer to `./vall_e/config.py` for additional configuration details.
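For a concrete sense of what step 3 produces per utterance, here is a minimal sketch; it is not the actual entry point, it just mirrors the `process_job` helper changed later in this commit, and the output path is illustrative:

    import numpy as np
    from vall_e.emb.g2p import encode as phonemize
    from vall_e.emb.qnt import encode as quantize

    def process_one(waveform, sample_rate, text, language="en", outpath="./training/example.enc"):
        # quantize to codebook indices for the configured backend, then store codes + metadata
        qnt = quantize(waveform, sr=sample_rate)
        state_dict = {
            "codes": qnt.cpu().numpy().astype(np.uint16),  # encodec / vocos path
            "metadata": {
                "original_length": waveform.shape[-1],
                "sample_rate": sample_rate,
                "text": text.strip(),
                "phonemes": phonemize(text, language=language),
                "language": language,
            },
        }
        np.save(open(outpath, "wb"), state_dict)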
### Dataset Formats

View File

@@ -157,6 +157,8 @@ class Dataset:
    max_resps: int = 1 # number of samples to target for training
    p_resp_append: float = 1.0 # probability to append another sample to the training target
+    p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
    sample_type: str = "path" # path | speaker
    sample_order: str = "interleaved" # duration
    sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
@@ -177,11 +179,8 @@ class Dataset:
        return self._frames_per_second

    if cfg.audio_backend == "dac":
-        # using the 44KHz model with 24KHz sources has a frame rate of 41Hz
-        if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
-            return 41
-        if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # to-do: find the actual value for 44.1K
-            return 86
+        if cfg.sample_rate == 44_100:
+            return 87

    if cfg.sample_rate == 16_000:
        return 50
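As a quick sanity check on those frame rates (a sketch only; the 75Hz figure for EnCodec/Vocos comes from the `get_framerate` helper added to `vall_e/emb/qnt.py` further down in this commit):

    # expected frame counts per backend, using the frame rates above
    def frames_for(duration_seconds, backend="dac", sample_rate=44_100):
        if backend == "dac":
            fps = 87 if sample_rate == 44_100 else 50   # 44.1KHz model vs 16KHz model
        else:
            fps = 75                                    # encodec / vocos at 24KHz
        return int(duration_seconds * fps)

    # e.g. frames_for(2.0) == 174 codes per codebook level for two seconds of DAC @ 44.1KHz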
@@ -712,14 +711,40 @@ class Config(BaseConfig):
    tokenizer_path: str = "./tokenizer.json" # tokenizer path
    sample_rate: int = 24_000 # sample rate the model expects
-    variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
    audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
    weights_format: str = "pth" # "pth" | "sft"
    supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
    def set_audio_backend(self, audio_backend):
        cfg.audio_backend = audio_backend
        audio_extension = None
        if audio_backend in ["encodec", "vocos"]:
            audio_extension = ".enc"
            cfg.sample_rate = 24_000
            cfg.model.resp_levels = 8
        elif audio_backend == "dac":
            audio_extension = ".dac"
            cfg.sample_rate = 44_100
            cfg.model.resp_levels = 9
        elif cfg.audio_backend == "audiodec":
            audio_extension = ".dec"
            cfg.sample_rate = 48_000
            cfg.model.resp_levels = 8 # ?
        else:
            raise Exception(f"Unknown audio backend: {audio_backend}")
    @property
    def audio_backend_extension(self):
        audio_extension = None
        if self.audio_backend in ["encodec", "vocos"]:
            audio_extension = ".enc"
        elif self.audio_backend == "dac":
            audio_extension = ".dac"
        elif self.audio_backend == "audiodec":
            audio_extension = ".dec"
        return audio_extension
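The new helper and property are meant to be used together; a minimal sketch, assuming the global `cfg` imported from `vall_e.config`:

    from vall_e.config import cfg

    cfg.set_audio_backend("dac")          # switches sample rate / resp levels in one place
    assert cfg.sample_rate == 44_100
    assert cfg.model.resp_levels == 9
    assert cfg.audio_backend_extension == ".dac"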
    @property
    def model(self):
        for i, model in enumerate(self.models):

View File

@@ -11,7 +11,7 @@ import torch
import itertools

from .config import cfg
-from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
+from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
@@ -368,7 +368,9 @@ def get_phone_symmap():
    return cfg.tokenizer.get_vocab()

def tokenize( phones ):
-    return cfg.tokenizer.encode( "".join(phones) )
+    if isinstance( phones, list ):
+        phones = "".join( phones )
+    return cfg.tokenizer.encode( phones )
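In other words, `tokenize` now accepts either a pre-joined phoneme string or a list of phoneme tokens; a small sketch (the phoneme strings here are purely illustrative):

    # both forms yield the same encoding once cfg.tokenizer is loaded
    tokens_from_list   = tokenize(["h", "ˈɛ", "l", "oʊ"])  # joined internally
    tokens_from_string = tokenize("hˈɛloʊ")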
def get_lang_symmap():
    return {
@@ -1146,6 +1148,10 @@ class Dataset(_Dataset):
        if text is None:
            text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)

        # pad the target with silence
        if random.random() < p_resp_pad_silence:
            resps = pad_codes_with_silence( resps )
        return dict(
            index=index,
            path=Path(path),

View File

@@ -54,17 +54,29 @@ def encode(text: str, language="en-us", backend="auto", punctuation=True, stress
    if not backend or backend == "auto":
        backend = "espeak" # if language[:2] != "en" else "festival"

-    text = [ text ]
    backend = _get_backend(language=language, backend=backend, stress=stress, strip=strip, punctuation=punctuation)
    if backend is not None:
-        tokens = backend.phonemize( text, strip=strip )
+        tokens = backend.phonemize( [ text ], strip=strip )
    else:
-        tokens = phonemize( text, language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress )
+        tokens = phonemize( [ text ], language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress )

    if not len(tokens):
-        tokens = []
-    else:
-        tokens = list(tokens[0])
-    return tokens
+        raise Exception(f"Failed to phonemize, received empty string: {text}")
+    return tokens[0]
# Helper function to debug phonemizer
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("string", type=str)
    parser.add_argument("--language", type=str, default="en-us")
    parser.add_argument("--backend", type=str, default="auto")
    parser.add_argument("--no-punctuation", action="store_true")
    parser.add_argument("--no-stress", action="store_true")
    parser.add_argument("--no-strip", action="store_true")
    args = parser.parse_args()

    phonemes = encode( args.string, language=args.language, backend=args.backend, punctuation=not args.no_punctuation, stress=not args.no_stress, strip=not args.no_strip )
    print( phonemes )
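A minimal sketch of calling the phonemizer directly instead of through the CLI (assumes an espeak-ng backend is available; the example sentence and output are illustrative):

    from vall_e.emb.g2p import encode as phonemize

    phones = phonemize("The quick brown fox.", language="en-us")
    print(phones)  # an IPA phoneme string, e.g. "ðə kwˈɪk bɹˈaʊn fˈɑːks."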

View File

@@ -17,7 +17,7 @@ from ..config import cfg
# need to validate if this is safe to import before modifying the config
from .g2p import encode as phonemize
-from .qnt import encode as quantize, _replace_file_extension
+from .qnt import encode as quantize

def pad(num, zeroes):
    return str(num).zfill(zeroes+1)
@@ -33,12 +33,11 @@ def process_items( items, stride=0, stride_offset=0 ):
    items = sorted( items )
    return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]

-def process_job( outpath, text, language, waveform, sample_rate ):
-    phones = phonemize(text, language=language)
+def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
    qnt = quantize(waveform, sr=sample_rate, device=waveform.device)

    if cfg.audio_backend == "dac":
-        np.save(open(outpath, "wb"), {
+        state_dict = {
            "codes": qnt.codes.cpu().numpy().astype(np.uint16),
            "metadata": {
                "original_length": qnt.original_length,
@@ -49,33 +48,35 @@ def process_job( outpath, text, language, waveform, sample_rate ):
                "channels": qnt.channels,
                "padding": qnt.padding,
                "dac_version": "1.0.0",
-                "text": text.strip(),
-                "phonemes": "".join(phones),
-                "language": language,
            },
-        })
+        }
    else:
-        np.save(open(outpath, "wb"), {
+        state_dict = {
            "codes": qnt.cpu().numpy().astype(np.uint16),
            "metadata": {
                "original_length": waveform.shape[-1],
                "sample_rate": sample_rate,
-                "text": text.strip(),
-                "phonemes": "".join(phones),
-                "language": language,
            },
-        })
+        }

+    if text:
+        text = text.strip()
+        state_dict['metadata'] |= {
+            "text": text,
+            "phonemes": phonemize(text, language=language),
+            "language": language,
+        }
+    np.save(open(outpath, "wb"), state_dict)
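Reading one of the saved artifacts back follows the same `np.save` convention; a sketch (the filename is illustrative, and the shapes assume the DAC backend):

    import numpy as np

    artifact = np.load("./training/example.dac", allow_pickle=True)[()]
    codes = artifact["codes"]        # uint16 array of codebook indices, (1, levels, frames) for DAC
    metadata = artifact["metadata"]  # includes text / phonemes / language when a transcription was given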
def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
    if not jobs:
        return

    for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
-        outpath, text, language, waveform, sample_rate = job
+        outpath, waveform, sample_rate, text, language = job
        try:
-            process_job( outpath, text, language, waveform, sample_rate )
+            process_job( outpath, waveform, sample_rate, text, language )
        except Exception as e:
            print(f"Failed to quantize: {outpath}:", e)
            if raise_exceptions:
@@ -98,30 +99,16 @@ def process(
    dtype="float16",
    amp=False,
):
-    # encodec / vocos
-    if audio_backend in ["encodec", "vocos"]:
-        audio_extension = ".enc"
-        cfg.sample_rate = 24_000
-        cfg.model.resp_levels = 8
-    elif audio_backend == "dac":
-        audio_extension = ".dac"
-        cfg.sample_rate = 44_100
-        cfg.model.resp_levels = 9
-    elif cfg.audio_backend == "audiodec":
-        sample_rate = 48_000
-        audio_extension = ".dec"
-        cfg.model.resp_levels = 8 # ?
-    else:
-        raise Exception(f"Unknown audio backend: {audio_backend}")
    # prepare from args
-    cfg.audio_backend = audio_backend # "encodec"
+    cfg.set_audio_backend(args.audio_backend)
+    audio_extension = cfg.audio_backend_extension
    cfg.inference.weight_dtype = dtype # "bfloat16"
    cfg.inference.amp = amp # False

    output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
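    # e.g. this resolves to "24KHz-encodec", "24KHz-vocos", "44KHz-dac", or "48KHz-audiodec"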
-    # to-do: make this also prepared from args
    language_map = {} # k = group, v = language
    ignore_groups = [] # skip these groups
@@ -164,8 +151,7 @@ def process(
        if speaker_id in audio_only:
            for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
                inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
-                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}')
-                outpath = _replace_file_extension(outpath, audio_extension)
+                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}').with_suffix(audio_extension)

                if outpath.exists():
                    continue
@@ -173,28 +159,7 @@ def process(
                waveform, sample_rate = load_audio( inpath, device )
                qnt = quantize(waveform, sr=sample_rate, device=device)

-                if cfg.audio_backend == "dac":
-                    np.save(open(outpath, "wb"), {
-                        "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                        "metadata": {
-                            "original_length": qnt.original_length,
-                            "sample_rate": qnt.sample_rate,
-                            "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                            "chunk_length": qnt.chunk_length,
-                            "channels": qnt.channels,
-                            "padding": qnt.padding,
-                            "dac_version": "1.0.0",
-                        },
-                    })
-                else:
-                    np.save(open(outpath, "wb"), {
-                        "codes": qnt.cpu().numpy().astype(np.uint16),
-                        "metadata": {
-                            "original_length": waveform.shape[-1],
-                            "sample_rate": sample_rate,
-                        },
-                    })
+                process_job(outpath, waveform, sample_rate)

                continue
@@ -229,8 +194,7 @@ def process(
            language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en")

            if len(metadata[filename]["segments"]) == 0 or not use_slices:
-                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
-                outpath = _replace_file_extension(outpath, audio_extension)
+                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}').with_suffix(audio_extension)
                text = metadata[filename]["text"]

                if len(text) == 0 or outpath.exists():
@@ -240,15 +204,14 @@ def process(
                if waveform is None:
                    waveform, sample_rate = load_audio( inpath, device )

-                jobs.append(( outpath, text, language, waveform, sample_rate ))
+                jobs.append(( outpath, waveform, sample_rate, text, language ))
            else:
                i = 0
                for segment in metadata[filename]["segments"]:
                    id = pad(i, 4)
                    i = i + 1
-                    outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
-                    outpath = _replace_file_extension(outpath, audio_extension)
+                    outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension)
                    text = segment["text"]

                    if len(text) == 0 or outpath.exists():
@@ -269,7 +232,7 @@ def process(
                    if end - start < 0:
                        continue

-                    jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate ))
+                    jobs.append(( outpath, waveform[:, start:end], sample_rate, text, language ))

        # processes audio files one at a time
        if low_memory:

View File

@@ -2,8 +2,10 @@ from ..config import cfg
import argparse
import random
+import math

import torch
import torchaudio
+import numpy as np

from functools import cache
from pathlib import Path
@@ -215,9 +217,12 @@ except Exception as e:
"""

@cache
-def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
+def _load_encodec_model(device="cuda", levels=0):
    assert cfg.sample_rate == 24_000
+    if not levels:
+        levels = cfg.model.max_levels
    # too lazy to un-if ladder this shit
    bandwidth_id = 6.0
    if levels == 2:
@@ -243,9 +248,12 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
    return model

@cache
-def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
+def _load_vocos_model(device="cuda", levels=0):
    assert cfg.sample_rate == 24_000
+    if not levels:
+        levels = cfg.model.max_levels
    model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    model = model.to(device)
    model = model.eval()
@@ -267,11 +275,10 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
    return model

@cache
-def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
+def _load_dac_model(device="cuda"):
    kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
-    if not cfg.variable_sample_rate:
    # yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
-    if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # because I messed up and had assumed it was an even 44K and not 44.1K
+    if cfg.sample_rate == 44_100:
        kwargs["model_type"] = "44khz"
    elif cfg.sample_rate == 16_000:
        kwargs["model_type"] = "16khz"
@@ -282,17 +289,13 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
    model = model.to(device)
    model = model.eval()

-    # to revisit later, but experiments shown that this is a bad idea
-    if cfg.variable_sample_rate:
-        model.sample_rate = cfg.sample_rate
    model.backend = "dac"
    model.model_type = kwargs["model_type"]

    return model

@cache
-def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
+def _load_audiodec_model(device="cuda", model_name=None):
    if not model_name:
        model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
    sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
@@ -307,25 +310,25 @@ def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_le
    return model

@cache
-def _load_model(device="cuda", backend=None, levels=cfg.model.max_levels):
+def _load_model(device="cuda", backend=None):
    if not backend:
        backend = cfg.audio_backend

    if backend == "audiodec":
-        return _load_audiodec_model(device, levels=levels)
+        return _load_audiodec_model(device)
    if backend == "dac":
-        return _load_dac_model(device, levels=levels)
+        return _load_dac_model(device)
    if backend == "vocos":
-        return _load_vocos_model(device, levels=levels)
+        return _load_vocos_model(device)
-    return _load_encodec_model(device, levels=levels)
+    return _load_encodec_model(device)

def unload_model():
    _load_model.cache_clear()
    _load_encodec_model.cache_clear() # because vocos can only decode

@torch.inference_mode()
-def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
+def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
    # upcast so it won't whine
    if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
        codes = codes.to(torch.int32)
@@ -342,7 +345,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
    assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'

    # load the model
-    model = _load_model(device, levels=levels)
+    model = _load_model(device)

    # AudioDec uses a different pathway
    if model.backend == "audiodec":
@@ -367,10 +370,11 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
            dummy = True
        elif hasattr( metadata, "__dict__" ):
            metadata = metadata.__dict__

        # generate object with copied metadata
        artifact = DACFile(
            codes = codes,
-            chunk_length = metadata["chunk_length"],
+            chunk_length = math.floor(window_duration * cfg.dataset.frames_per_second) if window_duration else metadata["chunk_length"],
            original_length = metadata["original_length"],
            input_db = metadata["input_db"],
            channels = metadata["channels"],
@@ -400,8 +404,8 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
    return wav, model.sample_rate

# huh
-def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels):
-    return decode(resps, device=device, levels=levels)
+def decode_to_wave(resps: Tensor, device="cuda"):
+    return decode(resps, device=device)

def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    wavs, sr = decode(resps, device=device)
@@ -471,23 +475,19 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device=
        raise Exception(f'Currently only DAC is supported')

@torch.inference_mode()
-def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
+def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
    # DAC uses a different pathway
    if cfg.audio_backend == "dac":
-        model = _load_dac_model(device, levels=levels )
+        model = _load_dac_model( device )
        signal = AudioSignal(wav, sample_rate=sr)

-        if not isinstance(levels, int):
-            levels = 8 if model.model_type == "24khz" else None
-        artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
-        #artifact = model.compress(signal, n_quantizers=levels)
+        artifact = model.compress(signal, win_duration=window_duration, verbose=False) # , n_quantizers=levels)
+        #artifact = model.compress(signal)
        return artifact.codes if not return_metadata else artifact

    # AudioDec uses a different pathway
    if cfg.audio_backend == "audiodec":
-        model = _load_audiodec_model(device, levels=levels )
+        model = _load_audiodec_model(device)
        wav = wav.unsqueeze(0)
        wav = convert_audio(wav, sr, model.sample_rate, 1)
        wav = wav.to(device)
@@ -498,7 +498,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
        return quantized

    # vocos does not encode wavs to encodecs, so just use normal encodec
-    model = _load_encodec_model(device, levels=levels)
+    model = _load_encodec_model(device)
    wav = wav.unsqueeze(0)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)
@@ -544,6 +544,9 @@ def encode_from_file(path, device="cuda"):
"""
Helper Functions
"""

+# DAC "silence": [ 568, 804, 10, 674, 364, 981, 568, 378, 731]

# trims from the start, up to `target`
def trim( qnt, target, reencode=False, device="cuda" ):
    length = max( qnt.shape[0], qnt.shape[1] )
@@ -613,23 +616,23 @@ def interleave_audio( *args, audio=None ):
    return res

# concats two audios together
-def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
+def concat_audio( *args, reencode=False, device="cuda" ):
    qnts = [ *args ]
    qnts = [ qnt for qnt in qnts if qnt is not None ]

    # just naively combine the codes
    if not reencode:
        return torch.concat( qnts )

-    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
+    decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]
    combined = torch.concat( decoded )
-    return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
+    return encode(combined, cfg.sample_rate, device=device)[0].t()

# merges two quantized audios together
# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
-def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
+def merge_audio( *args, device="cuda", scale=[] ):
    qnts = [ *args ]
    qnts = [ qnt for qnt in qnts if qnt is not None ]
-    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
+    decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]

    # max length
    max_length = max([ wav.shape[-1] for wav in decoded ])
@@ -646,17 +649,138 @@ def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
        decoded[i] = decoded[i] * scale[i]

    combined = sum(decoded) / len(decoded)
-    return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
+    return encode(combined, cfg.sample_rate, device=device)[0].t()

-"""
+# Get framerate for a given audio backend
def get_framerate( backend=None, sample_rate=None ):
    if not backend:
        backend = cfg.audio_backend
    if not sample_rate:
        sample_rate = cfg.sample_rate

    if backend == "dac":
        if sample_rate == 44_100:
            return 87
        if sample_rate == 16_000:
            return 50

    # 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz
    return 75

# Generates quantized silence
def get_silence( length, device=None, codes=None ):
    length = math.floor(length * get_framerate())
    if cfg.audio_backend == "dac":
        codes = [ 568, 804, 10, 674, 364, 981, 568, 378, 731 ]
    else:
        codes = [ 62, 424, 786, 673, 622, 986, 570, 948 ]

    return torch.tensor([ codes for _ in range( length ) ], device=device, dtype=torch.int16)

# Pads a sequence of codes with silence
def pad_codes_with_silence( codes, size=1 ):
    duration = codes.shape[0] / get_framerate()
    difference = math.ceil( duration + size ) - duration
    silence = get_silence( difference, device=codes.device )

    half = math.floor(difference / 2 * get_framerate())
    return torch.concat( [ silence[half:, :], codes, silence[:half, :] ], dim=0 )

# Generates an empty waveform
def get_silent_waveform( length, device=None ):
    length = math.floor(length * cfg.sample_rate)
    return torch.tensor( [ [ 0 for _ in range( length ) ] ], device=device, dtype=torch.float32 )

# Pads a waveform with silence
def pad_waveform_with_silence( waveform, sample_rate, size=1 ):
    duration = waveform.shape[-1] / sample_rate
    difference = math.ceil( duration + size ) - duration
    silence = get_silent_waveform( difference, device=waveform.device )

    half = math.floor(difference / 2 * sample_rate)
    return torch.concat( [ silence[:, half:], waveform, silence[:, :half] ], dim=-1 )
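A quick usage sketch for the silence helpers above (a sketch only; it assumes the DAC backend at roughly 87 frames per second, with `codes` as a (frames, levels) tensor and `wav` as a (channels, samples) tensor):

    silence = get_silence(1.0)                             # ~87 frames of quantized "silence", one row per frame
    padded_codes = pad_codes_with_silence(codes)           # pads codes out to roughly the next whole second
    padded_wav   = pad_waveform_with_silence(wav, 44_100)  # same idea, applied to the raw waveform instead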
# Encodes/decodes audio, and helps me debug things
if __name__ == "__main__":
-    cfg.sample_rate = 48_000
-    cfg.audio_backend = "audiodec"
-    wav, sr = torchaudio.load("in.wav")
-    codes = encode( wav, sr ).t() # for some reason
-    print( "ENCODED:", codes.shape, codes )
-    wav, sr = decode( codes )
-    print( "DECODED:", wav.shape, wav )
-    torchaudio.save("out.wav", wav.cpu(), sr)
-"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-backend", type=str, default="encodec")
    parser.add_argument("--input", type=Path)
    parser.add_argument("--output", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    parser.add_argument("--window-duration", type=float, default=None) # for DAC, the window duration for encoding / decoding
    parser.add_argument("--print", action="store_true") # prints codes and metadata
parser.add_argument("--pad", action="store_true") # to test if padding with silence modifies the waveform / quants too much
args = parser.parse_args()
# prepare from args
cfg.set_audio_backend(args.audio_backend)
audio_extension = cfg.audio_backend_extension
cfg.inference.weight_dtype = args.dtype # "bfloat16"
cfg.inference.amp = args.dtype != "float32"
cfg.device = args.device
# decode
if args.input.suffix == audio_extension:
args.output = args.input.with_suffix('.wav') if not args.output else args.output.with_suffix('.wav')
artifact = np.load(args.input, allow_pickle=True)[()]
codes = torch.from_numpy(artifact['codes'])[0][:, :].t().to(device=cfg.device, dtype=torch.int16)
# pad to nearest
if args.pad:
codes = pad_codes_with_silence( codes )
del artifact['metadata']
waveform, sample_rate = decode( codes, device=cfg.device, metadata=artifact['metadata'] if 'metadata' in artifact else None, window_duration=args.window_duration )
torchaudio.save(args.output, waveform.cpu(), sample_rate)
# print
if args.print:
torch.set_printoptions(profile="full")
print( "Metadata:", artifact['metadata'] )
print( "Codes:", codes.shape, codes )
# encode
else:
args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)
waveform, sample_rate = torchaudio.load(args.input)
# pad to nearest
if args.pad:
waveform = pad_waveform_with_silence( waveform, sample_rate )
qnt = encode(waveform.to(cfg.device), sr=sample_rate, device=cfg.device, window_duration=args.window_duration)
if cfg.audio_backend == "dac":
state_dict = {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
},
}
else:
state_dict = {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,
},
}
np.save(open(args.output, "wb"), state_dict)
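For reference, a hypothetical round trip with this debug CLI (assuming the package is importable, so the module can be invoked with `-m`; filenames are illustrative):

    # encode a wav into a quantized artifact, then decode that artifact back to a wav
    #   python3 -m vall_e.emb.qnt --audio-backend dac --input utterance.wav
    #   python3 -m vall_e.emb.qnt --audio-backend dac --input utterance.dac --output roundtrip.wav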