busy work and cleanup while I wait for 1TB of audio to quantize... again.
parent f284c7ea9c → commit eac353cd0b
@@ -58,13 +58,13 @@ If you already have a dataset you want, for example, your own large corpus or fo
 1. Populate your source voices under `./voices/{group name}/{speaker name}/`.
 
-2. Run `python3 ./scripts/transcribe_dataset.py`. This will generate a transcription with timestamps for your dataset.
+2. Run `python3 -m vall_e.emb.transcribe`. This will generate a transcription with timestamps for your dataset.
    + If you're interested in using a different model, edit the script's `model_name` and `batch_size` variables.
 
-3. Run `python3 ./scripts/process_dataset.py`. This will phonemize the transcriptions and quantize the audio.
+3. Run `python3 -m vall_e.emb.process`. This will phonemize the transcriptions and quantize the audio.
    + If you're using a Descript-Audio-Codec based model, ensure to set the sample rate and audio backend accordingly.
 
-4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset_list.json`.
+4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset/list.json`.
    + Refer to `./vall_e/config.py` for additional configuration details.
 
 ### Dataset Formats
 
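Side note on the renamed entry points: the two helper scripts are now invoked as modules, so the preparation pass is just a couple of commands. A minimal sketch, assuming the package is importable and you run from the repo root:

```python
# Sketch of the updated workflow from the README hunk above (invocations assumed, adjust as needed).
import subprocess

# step 2: generate timestamped transcriptions for everything under ./voices/
subprocess.run(["python3", "-m", "vall_e.emb.transcribe"], check=True)

# step 3: phonemize the transcriptions and quantize the audio with the configured backend
subprocess.run(["python3", "-m", "vall_e.emb.process"], check=True)
```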
@@ -157,6 +157,8 @@ class Dataset:
     max_resps: int = 1 # number of samples to target for training
     p_resp_append: float = 1.0 # probability to append another sample to the training target
 
+    p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
+
     sample_type: str = "path" # path | speaker
     sample_order: str = "interleaved" # duration
     sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
@@ -177,11 +179,8 @@ class Dataset:
             return self._frames_per_second
 
         if cfg.audio_backend == "dac":
-            # using the 44KHz model with 24KHz sources has a frame rate of 41Hz
-            if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
-                return 41
-            if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # to-do: find the actual value for 44.1K
-                return 86
+            if cfg.sample_rate == 44_100:
+                return 87
             if cfg.sample_rate == 16_000:
                 return 50
 
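For quick reference, the corrected frame rates this hunk implies; a small sketch of the table it encodes (not the repo's code, with 75Hz as the EnCodec/Vocos rate used elsewhere in this commit):

```python
# Assumed summary of the frame-rate table above (see also get_framerate() added to qnt.py further down).
def frames_per_second(audio_backend: str, sample_rate: int) -> int:
    if audio_backend == "dac":
        if sample_rate == 44_100:
            return 87
        if sample_rate == 16_000:
            return 50
    # 24KHz EnCodec / Vocos run at 75 frames per second
    return 75
```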
@@ -712,14 +711,40 @@ class Config(BaseConfig):
     tokenizer_path: str = "./tokenizer.json" # tokenizer path
 
     sample_rate: int = 24_000 # sample rate the model expects
     variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
 
     audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
 
     weights_format: str = "pth" # "pth" | "sft"
 
     supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
 
+    def set_audio_backend(self, audio_backend):
+        cfg.audio_backend = audio_backend
+        audio_extension = None
+        if audio_backend in ["encodec", "vocos"]:
+            audio_extension = ".enc"
+            cfg.sample_rate = 24_000
+            cfg.model.resp_levels = 8
+        elif audio_backend == "dac":
+            audio_extension = ".dac"
+            cfg.sample_rate = 44_100
+            cfg.model.resp_levels = 9
+        elif cfg.audio_backend == "audiodec":
+            audio_extension = ".dec"
+            sample_rate = 48_000
+            cfg.model.resp_levels = 8 # ?
+        else:
+            raise Exception(f"Unknown audio backend: {audio_backend}")
+
+    @property
+    def audio_backend_extension(self):
+        audio_extension = None
+        if self.audio_backend in ["encodec", "vocos"]:
+            audio_extension = ".enc"
+        elif self.audio_backend == "dac":
+            audio_extension = ".dac"
+        elif self.audio_backend == "audiodec":
+            audio_extension = ".dec"
+        return audio_extension
 
     @property
     def model(self):
         for i, model in enumerate(self.models):
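A short sketch of how the new helpers are meant to be used downstream (the path here is hypothetical; the real call sites are in the processing script further down):

```python
# Sketch: picking a backend now also pins the sample rate, resp_levels, and file extension.
from pathlib import Path
from vall_e.config import cfg

cfg.set_audio_backend("dac")        # also sets cfg.sample_rate = 44_100 and cfg.model.resp_levels = 9
ext = cfg.audio_backend_extension   # ".dac" here; ".enc" for encodec/vocos, ".dec" for audiodec
outpath = Path("./training/example.wav").with_suffix(ext)   # hypothetical path, mirrors the process.py change
```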
@@ -11,7 +11,7 @@ import torch
 import itertools
 
 from .config import cfg
-from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
+from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 from .utils.io import torch_save, torch_load
@@ -368,7 +368,9 @@ def get_phone_symmap():
     return cfg.tokenizer.get_vocab()
 
 def tokenize( phones ):
-    return cfg.tokenizer.encode( "".join(phones) )
+    if isinstance( phones, list ):
+        phones = "".join( phones )
+    return cfg.tokenizer.encode( phones )
 
 def get_lang_symmap():
     return {
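In other words, tokenize() now accepts either a pre-joined string or a list of phonemes; a tiny sketch (the phonemes are made up):

```python
# Sketch: both call styles go through the same encoding path in tokenize() above.
phones = ["h", "ə", "l", "ˈoʊ"]                      # hypothetical phoneme list from the g2p step
tokens_from_list   = tokenize( phones )              # joined internally
tokens_from_string = tokenize( "".join( phones ) )   # equivalent
```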
@@ -1146,6 +1148,10 @@ class Dataset(_Dataset):
         if text is None:
             text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
+        # pad the target with silence
+        if p_resp_pad_silence < random.random():
+            resps = pad_codes_with_silence( resps )
+
         return dict(
             index=index,
             path=Path(path),
@@ -54,17 +54,29 @@ def encode(text: str, language="en-us", backend="auto", punctuation=True, stress
     if not backend or backend == "auto":
         backend = "espeak" # if language[:2] != "en" else "festival"
 
-    text = [ text ]
-
     backend = _get_backend(language=language, backend=backend, stress=stress, strip=strip, punctuation=punctuation)
     if backend is not None:
-        tokens = backend.phonemize( text, strip=strip )
+        tokens = backend.phonemize( [ text ], strip=strip )
     else:
-        tokens = phonemize( text, language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress )
+        tokens = phonemize( [ text ], language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress )
 
     if not len(tokens):
-        tokens = []
-    else:
-        tokens = list(tokens[0])
+        raise Exception(f"Failed to phonemize, received empty string: {text}")
 
-    return tokens
+    return tokens[0]
+
+# Helper function to debug phonemizer
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("string", type=str)
+    parser.add_argument("--language", type=str, default="en-us")
+    parser.add_argument("--backend", type=str, default="auto")
+    parser.add_argument("--no-punctuation", action="store_true")
+    parser.add_argument("--no-stress", action="store_true")
+    parser.add_argument("--no-strip", action="store_true")
+
+    args = parser.parse_args()
+
+    phonemes = encode( args.string, language=args.language, backend=args.backend, punctuation=not args.no_punctuation, stress=not args.no_stress, strip=not args.no_strip )
+    print( phonemes )
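The net effect of the g2p change: encode() phonemizes one string, returns the first (only) result, and raises instead of silently returning an empty list. A hedged usage sketch (exact output depends on the installed espeak backend; the module path is assumed):

```python
# Sketch of the new single-utterance behaviour.
from vall_e.emb.g2p import encode as phonemize

phonemes = phonemize("Hello world.", language="en-us")
print(phonemes)   # a list of phonemes for the one utterance; joining it gives the stored "phonemes" string
```

The added __main__ block exposes the same call from the shell, e.g. `python3 -m vall_e.emb.g2p "Hello world." --language en-us`.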
@@ -17,7 +17,7 @@ from ..config import cfg
 
 # need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
-from .qnt import encode as quantize, _replace_file_extension
+from .qnt import encode as quantize
 
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
@@ -33,12 +33,11 @@ def process_items( items, stride=0, stride_offset=0 ):
     items = sorted( items )
     return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
 
-def process_job( outpath, text, language, waveform, sample_rate ):
-    phones = phonemize(text, language=language)
+def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
     qnt = quantize(waveform, sr=sample_rate, device=waveform.device)
 
     if cfg.audio_backend == "dac":
-        np.save(open(outpath, "wb"), {
+        state_dict = {
             "codes": qnt.codes.cpu().numpy().astype(np.uint16),
             "metadata": {
                 "original_length": qnt.original_length,
@@ -49,33 +48,35 @@ def process_job( outpath, text, language, waveform, sample_rate ):
                 "channels": qnt.channels,
                 "padding": qnt.padding,
                 "dac_version": "1.0.0",
-
-                "text": text.strip(),
-                "phonemes": "".join(phones),
-                "language": language,
             },
-        })
+        }
     else:
-        np.save(open(outpath, "wb"), {
+        state_dict = {
             "codes": qnt.cpu().numpy().astype(np.uint16),
             "metadata": {
                 "original_length": waveform.shape[-1],
                 "sample_rate": sample_rate,
-
-                "text": text.strip(),
-                "phonemes": "".join(phones),
-                "language": language,
             },
-        })
+        }
+
+    if text:
+        text = text.strip()
+        state_dict['metadata'] |= {
+            "text": text,
+            "phonemes": phonemize(text, language=language),
+            "language": language,
+        }
+
+    np.save(open(outpath, "wb"), state_dict)
 
 def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
     if not jobs:
         return
 
     for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
-        outpath, text, language, waveform, sample_rate = job
+        outpath, waveform, sample_rate, text, language = job
         try:
-            process_job( outpath, text, language, waveform, sample_rate )
+            process_job( outpath, waveform, sample_rate, text, language )
         except Exception as e:
            print(f"Failed to quantize: {outpath}:", e)
            if raise_exceptions:
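Since the job tuple was reordered, a quick note on the new layout; a hypothetical sketch (names and paths made up):

```python
from pathlib import Path
import torch

# New ordering: audio first, text/language last, and text is optional (None skips the transcript metadata).
waveform, sample_rate = torch.zeros(1, 24_000), 24_000        # one second of silence as a stand-in
outpath = Path("./training/example.enc")                      # hypothetical output path

job = ( outpath, waveform, sample_rate, "an example transcript", "en" )   # what process_jobs() unpacks
# audio-only files can now be queued without any text at all:
# process_job( outpath, waveform, sample_rate )
```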
@@ -98,30 +99,16 @@ def process(
     dtype="float16",
     amp=False,
 ):
-    # encodec / vocos
-    if audio_backend in ["encodec", "vocos"]:
-        audio_extension = ".enc"
-        cfg.sample_rate = 24_000
-        cfg.model.resp_levels = 8
-    elif audio_backend == "dac":
-        audio_extension = ".dac"
-        cfg.sample_rate = 44_100
-        cfg.model.resp_levels = 9
-    elif cfg.audio_backend == "audiodec":
-        sample_rate = 48_000
-        audio_extension = ".dec"
-        cfg.model.resp_levels = 8 # ?
-    else:
-        raise Exception(f"Unknown audio backend: {audio_backend}")
     # prepare from args
-    cfg.audio_backend = audio_backend # "encodec"
+    cfg.set_audio_backend(args.audio_backend)
+    audio_extension = cfg.audio_backend_extension
 
     cfg.inference.weight_dtype = dtype # "bfloat16"
     cfg.inference.amp = amp # False
 
     output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
 
     # to-do: make this also prepared from args
     language_map = {} # k = group, v = language
 
     ignore_groups = [] # skip these groups
@@ -164,8 +151,7 @@ def process(
         if speaker_id in audio_only:
             for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
                 inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
-                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}')
-                outpath = _replace_file_extension(outpath, audio_extension)
+                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}').with_suffix(audio_extension)
 
                 if outpath.exists():
                     continue
@@ -173,28 +159,7 @@ def process(
                 waveform, sample_rate = load_audio( inpath, device )
-                qnt = quantize(waveform, sr=sample_rate, device=device)
-
-                if cfg.audio_backend == "dac":
-                    np.save(open(outpath, "wb"), {
-                        "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                        "metadata": {
-                            "original_length": qnt.original_length,
-                            "sample_rate": qnt.sample_rate,
-
-                            "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                            "chunk_length": qnt.chunk_length,
-                            "channels": qnt.channels,
-                            "padding": qnt.padding,
-                            "dac_version": "1.0.0",
-                        },
-                    })
-                else:
-                    np.save(open(outpath, "wb"), {
-                        "codes": qnt.cpu().numpy().astype(np.uint16),
-                        "metadata": {
-                            "original_length": waveform.shape[-1],
-                            "sample_rate": sample_rate,
-                        },
-                    })
+                process_job(outpath, waveform, sample_rate)
 
                 continue
 
@@ -229,8 +194,7 @@ def process(
             language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en")
 
             if len(metadata[filename]["segments"]) == 0 or not use_slices:
-                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
-                outpath = _replace_file_extension(outpath, audio_extension)
+                outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}').with_suffix(audio_extension)
                 text = metadata[filename]["text"]
 
                 if len(text) == 0 or outpath.exists():
@@ -240,15 +204,14 @@ def process(
                 if waveform is None:
                     waveform, sample_rate = load_audio( inpath, device )
 
-                jobs.append(( outpath, text, language, waveform, sample_rate ))
+                jobs.append(( outpath, waveform, sample_rate, text, language ))
             else:
                 i = 0
                 for segment in metadata[filename]["segments"]:
                     id = pad(i, 4)
                     i = i + 1
 
-                    outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
-                    outpath = _replace_file_extension(outpath, audio_extension)
+                    outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension)
                     text = segment["text"]
 
                     if len(text) == 0 or outpath.exists():
@@ -269,7 +232,7 @@ def process(
                     if end - start < 0:
                         continue
 
-                    jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate ))
+                    jobs.append(( outpath, waveform[:, start:end], sample_rate, text, language ))
 
     # processes audio files one at a time
     if low_memory:
@@ -2,8 +2,10 @@ from ..config import cfg
 
 import argparse
+import random
+import math
 import torch
 import torchaudio
 import numpy as np
 
 from functools import cache
 from pathlib import Path
@@ -215,9 +217,12 @@ except Exception as e:
 """
 
 @cache
-def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
+def _load_encodec_model(device="cuda", levels=0):
     assert cfg.sample_rate == 24_000
 
+    if not levels:
+        levels = cfg.model.max_levels
+
     # too lazy to un-if ladder this shit
     bandwidth_id = 6.0
     if levels == 2:
@@ -243,9 +248,12 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
     return model
 
 @cache
-def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
+def _load_vocos_model(device="cuda", levels=0):
     assert cfg.sample_rate == 24_000
 
+    if not levels:
+        levels = cfg.model.max_levels
+
     model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
     model = model.to(device)
     model = model.eval()
@@ -267,32 +275,27 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
     return model
 
 @cache
-def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
+def _load_dac_model(device="cuda"):
     kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
-    if not cfg.variable_sample_rate:
-        # yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
-        if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # because I messed up and had assumed it was an even 44K and not 44.1K
-            kwargs["model_type"] = "44khz"
-        elif cfg.sample_rate == 16_000:
-            kwargs["model_type"] = "16khz"
-        else:
-            raise Exception(f'unsupported sample rate: {cfg.sample_rate}')
+    # yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
+    if cfg.sample_rate == 44_100:
+        kwargs["model_type"] = "44khz"
+    elif cfg.sample_rate == 16_000:
+        kwargs["model_type"] = "16khz"
+    else:
+        raise Exception(f'unsupported sample rate: {cfg.sample_rate}')
 
     model = __load_dac_model(**kwargs)
     model = model.to(device)
     model = model.eval()
 
-    # to revisit later, but experiments shown that this is a bad idea
-    if cfg.variable_sample_rate:
-        model.sample_rate = cfg.sample_rate
-
     model.backend = "dac"
     model.model_type = kwargs["model_type"]
 
     return model
 
 @cache
-def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
+def _load_audiodec_model(device="cuda", model_name=None):
     if not model_name:
         model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
     sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
@@ -307,25 +310,25 @@ def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_le
     return model
 
 @cache
-def _load_model(device="cuda", backend=None, levels=cfg.model.max_levels):
+def _load_model(device="cuda", backend=None):
     if not backend:
         backend = cfg.audio_backend
 
     if backend == "audiodec":
-        return _load_audiodec_model(device, levels=levels)
+        return _load_audiodec_model(device)
     if backend == "dac":
-        return _load_dac_model(device, levels=levels)
+        return _load_dac_model(device)
     if backend == "vocos":
-        return _load_vocos_model(device, levels=levels)
+        return _load_vocos_model(device)
 
-    return _load_encodec_model(device, levels=levels)
+    return _load_encodec_model(device)
 
 def unload_model():
     _load_model.cache_clear()
     _load_encodec_model.cache_clear() # because vocos can only decode
 
 @torch.inference_mode()
-def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
+def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
     # upcast so it won't whine
     if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
         codes = codes.to(torch.int32)
@@ -342,7 +345,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
     assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
 
     # load the model
-    model = _load_model(device, levels=levels)
+    model = _load_model(device)
 
     # AudioDec uses a different pathway
     if model.backend == "audiodec":
@@ -356,7 +359,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
     dummy = False
     if metadata is None:
         metadata = dict(
-            chunk_length= codes.shape[-1],
+            chunk_length=codes.shape[-1],
             original_length=0,
             input_db=-12,
             channels=1,
@@ -367,10 +370,11 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
         dummy = True
     elif hasattr( metadata, "__dict__" ):
         metadata = metadata.__dict__
 
     # generate object with copied metadata
     artifact = DACFile(
         codes = codes,
-        chunk_length = metadata["chunk_length"],
+        chunk_length = math.floor(window_duration * cfg.dataset.frames_per_second) if window_duration else metadata["chunk_length"],
         original_length = metadata["original_length"],
         input_db = metadata["input_db"],
         channels = metadata["channels"],
@@ -400,8 +404,8 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
     return wav, model.sample_rate
 
 # huh
-def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels):
-    return decode(resps, device=device, levels=levels)
+def decode_to_wave(resps: Tensor, device="cuda"):
+    return decode(resps, device=device)
 
 def decode_to_file(resps: Tensor, path: Path, device="cuda"):
     wavs, sr = decode(resps, device=device)
@@ -471,23 +475,19 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device=
     raise Exception(f'Currently only DAC is supported')
 
 @torch.inference_mode()
-def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
+def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
     # DAC uses a different pathway
     if cfg.audio_backend == "dac":
-        model = _load_dac_model(device, levels=levels )
+        model = _load_dac_model( device )
         signal = AudioSignal(wav, sample_rate=sr)
 
-        if not isinstance(levels, int):
-            levels = 8 if model.model_type == "24khz" else None
-
-        artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
-        #artifact = model.compress(signal, n_quantizers=levels)
+        artifact = model.compress(signal, win_duration=window_duration, verbose=False) # , n_quantizers=levels)
+        #artifact = model.compress(signal)
         return artifact.codes if not return_metadata else artifact
 
     # AudioDec uses a different pathway
     if cfg.audio_backend == "audiodec":
-        model = _load_audiodec_model(device, levels=levels )
+        model = _load_audiodec_model(device)
         wav = wav.unsqueeze(0)
         wav = convert_audio(wav, sr, model.sample_rate, 1)
         wav = wav.to(device)
@@ -498,7 +498,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
         return quantized
 
     # vocos does not encode wavs to encodecs, so just use normal encodec
-    model = _load_encodec_model(device, levels=levels)
+    model = _load_encodec_model(device)
     wav = wav.unsqueeze(0)
     wav = convert_audio(wav, sr, model.sample_rate, model.channels)
     wav = wav.to(device)
@@ -544,6 +544,9 @@ def encode_from_file(path, device="cuda"):
 """
 Helper Functions
 """
 
+# DAC "silence": [ 568, 804, 10, 674, 364, 981, 568, 378, 731]
+
 # trims from the start, up to `target`
 def trim( qnt, target, reencode=False, device="cuda" ):
     length = max( qnt.shape[0], qnt.shape[1] )
@@ -613,23 +616,23 @@ def interleave_audio( *args, audio=None ):
     return res
 
 # concats two audios together
-def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
+def concat_audio( *args, reencode=False, device="cuda" ):
     qnts = [ *args ]
     qnts = [ qnt for qnt in qnts if qnt is not None ]
     # just naively combine the codes
     if not reencode:
         return torch.concat( qnts )
 
-    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
+    decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]
     combined = torch.concat( decoded )
-    return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
+    return encode(combined, cfg.sample_rate, device=device)[0].t()
 
 # merges two quantized audios together
 # requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
-def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
+def merge_audio( *args, device="cuda", scale=[] ):
     qnts = [ *args ]
     qnts = [ qnt for qnt in qnts if qnt is not None ]
-    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
+    decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]
 
     # max length
     max_length = max([ wav.shape[-1] for wav in decoded ])
@@ -646,17 +649,138 @@ def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
         decoded[i] = decoded[i] * scale[i]
 
     combined = sum(decoded) / len(decoded)
-    return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
+    return encode(combined, cfg.sample_rate, device=device)[0].t()
 
-"""
+# Get framerate for a given audio backend
+def get_framerate( backend=None, sample_rate=None ):
+    if not backend:
+        backend = cfg.audio_backend
+    if not sample_rate:
+        sample_rate = cfg.sample_rate
+
+    if backend == "dac":
+        if sample_rate == 44_100:
+            return 87
+        if sample_rate == 16_000:
+            return 50
+
+    # 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz
+    return 75
+
+# Generates quantized silence
+def get_silence( length, device=None, codes=None ):
+    length = math.floor(length * get_framerate())
+    if cfg.audio_backend == "dac":
+        codes = [ 568, 804, 10, 674, 364, 981, 568, 378, 731 ]
+    else:
+        codes = [ 62, 424, 786, 673, 622, 986, 570, 948 ]
+
+    return torch.tensor([ codes for _ in range( length ) ], device=device, dtype=torch.int16)
+
+# Pads a sequence of codes with silence
+def pad_codes_with_silence( codes, size=1 ):
+    duration = codes.shape[0] * get_framerate()
+    difference = math.ceil( duration + size ) - duration
+
+    silence = get_silence( difference, device=codes.device )
+
+    half = math.floor(difference / 2 * get_framerate())
+
+    return torch.concat( [ silence[half:, :], codes, silence[:half, :] ], dim=0 )
+
+# Generates an empty waveform
+def get_silent_waveform( length, device=None ):
+    length = math.floor(length * cfg.sample_rate)
+    return torch.tensor( [ [ 0 for _ in range( length ) ] ], device=device, dtype=torch.float32 )
+
+# Pads a waveform with silence
+def pad_waveform_with_silence( waveform, sample_rate, size=1 ):
+    duration = waveform.shape[-1] / sample_rate
+    difference = math.ceil( duration + size ) - duration
+
+    silence = get_silent_waveform( difference, device=waveform.device )
+
+    half = math.floor(difference / 2 * sample_rate)
+
+    return torch.concat( [ silence[:, half:], waveform, silence[:, :half] ], dim=-1 )
+
+# Encodes/decodes audio, and helps me debug things
 if __name__ == "__main__":
-    cfg.sample_rate = 48_000
-    cfg.audio_backend = "audiodec"
+    parser = argparse.ArgumentParser()
 
-    wav, sr = torchaudio.load("in.wav")
-    codes = encode( wav, sr ).t() # for some reason
-    print( "ENCODED:", codes.shape, codes )
-    wav, sr = decode( codes )
-    print( "DECODED:", wav.shape, wav )
-    torchaudio.save("out.wav", wav.cpu(), sr)
-"""
+    parser.add_argument("--audio-backend", type=str, default="encodec")
+    parser.add_argument("--input", type=Path)
+    parser.add_argument("--output", type=Path, default=None)
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--dtype", type=str, default="float16")
+    parser.add_argument("--window-duration", type=float, default=None) # for DAC, the window duration for encoding / decoding
+    parser.add_argument("--print", action="store_true") # prints codes and metadata
+    parser.add_argument("--pad", action="store_true") # to test if padding with silence modifies the waveform / quants too much
+
+    args = parser.parse_args()
+
+    # prepare from args
+    cfg.set_audio_backend(args.audio_backend)
+    audio_extension = cfg.audio_backend_extension
+
+    cfg.inference.weight_dtype = args.dtype # "bfloat16"
+    cfg.inference.amp = args.dtype != "float32"
+    cfg.device = args.device
+
+    # decode
+    if args.input.suffix == audio_extension:
+        args.output = args.input.with_suffix('.wav') if not args.output else args.output.with_suffix('.wav')
+
+        artifact = np.load(args.input, allow_pickle=True)[()]
+        codes = torch.from_numpy(artifact['codes'])[0][:, :].t().to(device=cfg.device, dtype=torch.int16)
+
+        # pad to nearest
+        if args.pad:
+            codes = pad_codes_with_silence( codes )
+            del artifact['metadata']
+
+        waveform, sample_rate = decode( codes, device=cfg.device, metadata=artifact['metadata'] if 'metadata' in artifact else None, window_duration=args.window_duration )
+
+        torchaudio.save(args.output, waveform.cpu(), sample_rate)
+
+        # print
+        if args.print:
+            torch.set_printoptions(profile="full")
+
+            print( "Metadata:", artifact['metadata'] )
+            print( "Codes:", codes.shape, codes )
+    # encode
+    else:
+        args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)
+
+        waveform, sample_rate = torchaudio.load(args.input)
+
+        # pad to nearest
+        if args.pad:
+            waveform = pad_waveform_with_silence( waveform, sample_rate )
+
+        qnt = encode(waveform.to(cfg.device), sr=sample_rate, device=cfg.device, window_duration=args.window_duration)
+
+        if cfg.audio_backend == "dac":
+            state_dict = {
+                "codes": qnt.codes.cpu().numpy().astype(np.uint16),
+                "metadata": {
+                    "original_length": qnt.original_length,
+                    "sample_rate": qnt.sample_rate,
+
+                    "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+                    "chunk_length": qnt.chunk_length,
+                    "channels": qnt.channels,
+                    "padding": qnt.padding,
+                    "dac_version": "1.0.0",
+                },
+            }
+        else:
+            state_dict = {
+                "codes": qnt.cpu().numpy().astype(np.uint16),
+                "metadata": {
+                    "original_length": waveform.shape[-1],
+                    "sample_rate": sample_rate,
+                },
+            }
+        np.save(open(args.output, "wb"), state_dict)
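The new silence helpers and the __main__ block above make it easy to round-trip a file; a minimal sketch of using them from Python (backend choice, device, and module path are assumptions):

```python
# Sketch: pad a waveform with roughly one second of silence split across both ends, then quantize it.
import torch
from vall_e.config import cfg
from vall_e.emb.qnt import encode, pad_waveform_with_silence

cfg.set_audio_backend("encodec")                              # 24KHz, 8 RVQ levels, ".enc"

waveform, sample_rate = torch.zeros(1, 30_000), 24_000        # 1.25s stand-in signal
padded = pad_waveform_with_silence( waveform, sample_rate )   # size=1 second by default
codes = encode( padded, sr=sample_rate, device="cuda" )
```

The same round trip is available from the shell via the new entry point, e.g. `python3 -m vall_e.emb.qnt --input in.wav --audio-backend encodec --pad`, with `--print` to dump the codes and metadata.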