cleaned up encode/decode functions to make them a little more coherent, added option to batch encode/decode (would have been very nice in the past, but this should speed things up for me when i fall for the latest meme codec)

mrq 2025-02-05 20:54:31 -06:00
parent 84174c1c1b
commit 79c504c278
5 changed files with 142 additions and 173 deletions
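For reference, a minimal sketch of how the new batch helpers might be called; the import path, file names, and backend caveat are assumptions for illustration, not anything this commit ships:

import torchaudio
# assumed import path, mirroring the relative ..emb.qnt imports in the diff below
from vall_e.emb.qnt import encode_batch, decode_batch

# encode several utterances in one padded batch; per-file sample rates are handled
wavs, srs = zip(*[torchaudio.load(p) for p in ["a.wav", "b.wav"]])  # hypothetical files
codes = encode_batch(list(wavs), sr=list(srs), device="cuda")
# decode the whole batch in one call (per the diff, only the nemo backend supports this)
decoded, sample_rate = decode_batch(codes, device="cuda")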

BIN  test.wav (binary file not shown)


@@ -245,7 +245,8 @@ class ModelExperimentalSettings:
# a model trained not summing audio embeddings *can* have this enabled without any apparent issues
# a model trained to sum *cannot* have this disabled without issues arising, or at least the ar+nar-retnet-8 can't.
# in theory a model that is trained to sum embeddings can perform better due to "seeing" previous levels (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings, currently not used
kv_heads: int = 0 # MHA or GQA (for supported backends)
rvq_levels_p: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necessary


@@ -18,6 +18,8 @@ from einops import rearrange
from torch import Tensor
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
try:
from .codecs.encodec import *
except Exception as e:
@@ -67,10 +69,8 @@ def _load_encodec_model(device="cuda", levels=0):
# extra metadata
model.bandwidth_id = bandwidth_id
model.sample_rate = cfg.sample_rate
model.normalize = cfg.inference.normalize
model.backend = "encodec"
model.device = device
return model
@@ -96,9 +96,7 @@ def _load_vocos_model(device="cuda", levels=0):
# extra metadata
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
model.sample_rate = cfg.sample_rate
model.backend = "vocos"
model.device = device
return model
@@ -119,23 +117,6 @@ def _load_dac_model(device="cuda"):
model.backend = "dac"
model.model_type = kwargs["model_type"]
#model.device = device
return model
@cache
def _load_audiodec_model(device="cuda", model_name=None):
if not model_name:
model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
model = AudioDec(tx_device=device, rx_device=device)
model.load_transmitter(encoder_checkpoint)
model.load_receiver(encoder_checkpoint, decoder_checkpoint)
model.backend = "audiodec"
model.sample_rate = sample_rate
model.device = device
return model
@@ -147,8 +128,6 @@ def _load_nemo_model(device="cuda", model_name=None):
model = AudioCodecModel.from_pretrained(model_name).to(device).eval()
model.backend = "nemo"
model.sample_rate = 44_100
#model.device = device
return model
@@ -173,41 +152,39 @@ def unload_model():
_load_model.cache_clear()
_load_encodec_model.cache_clear() # because vocos can only decode
# to-do: clean up this mess
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
# upcast so it won't whine
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
codes = codes.to(torch.int32)
# expand if we're given a raw 1-RVQ stream
if codes.dim() == 1:
codes = rearrange(codes, "t -> 1 1 t")
# expand to a batch size of one if not passed as a batch
# vocos does not do batch decoding, but encodec does, but we don't end up using this anyways *I guess*
# to-do, make this logical
elif codes.dim() == 2:
codes = rearrange(codes, "t q -> 1 q t")
# expand to a batch size of one if not passed as a batch
elif codes.dim() == 2:
# if (t, q), transpose to (q, t) instead
if codes.shape[0] > codes.shape[1]:
codes = codes.t()
codes = codes.unsqueeze(0)
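# e.g. a (500, 8) tensor (500 frames, 8 RVQ levels) is transposed and unsqueezed here into (1, 8, 500)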
# life is easier if we assume we're using a batch
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
# load the model
model = _load_model(device)
# move to device
codes = codes.to( device=device )
# NeMo uses a different pathway
if model.backend == "nemo":
# ugh
codes = rearrange( codes, "b q t -> b t q")
codes = codes.to( device=device )
l = torch.tensor([codes.shape[-1]], device=device, dtype=torch.int32)
l = torch.tensor([c.shape[-1] for c in codes], device=device, dtype=torch.int32)
wav, _ = model.decode(tokens=codes, tokens_len=l)
return wav, model.sample_rate
# AudioDec uses a different pathway
if model.backend == "audiodec":
codes = codes.to( device=device )[0]
zq = model.rx_encoder.lookup( codes )
wav = model.decoder.decode(zq).squeeze(1)
return wav, model.sample_rate
return wav, cfg.sample_rate
assert codes.shape[0] == 1, f'Batch decoding is unsupported for backend: {model.backend}'
# DAC uses a different pathway
if model.backend == "dac":
@@ -218,7 +195,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
original_length=0,
input_db=-12,
channels=1,
sample_rate=model.sample_rate,
sample_rate=cfg.sample_rate,
padding=True,
dac_version='1.0.0',
)
@@ -242,21 +219,52 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
# to-do: inject the sample rate encoded at, because we can actually decouple
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
kwargs = {}
# cleaner to separate out from EnCodec's pathway
if model.backend == "vocos":
x = model.codes_to_features(codes[0])
kwargs['bandwidth_id'] = model.bandwidth_id
else:
# encodec will decode as a batch
x = [(codes.to(device), None)]
wav = model.decode(x, bandwidth_id=model.bandwidth_id)
return wav, cfg.sample_rate
wav = model.decode(x, **kwargs)
# encodec will decode as a batch
if model.backend == "encodec":
wav = wav[0]
x = [(codes.to(device), None)]
wav = model.decode(x)
return wav, cfg.sample_rate
return wav, model.sample_rate
@torch.inference_mode()
def decode_batch(codes: list[Tensor], device="cuda"):
# transpose if needed
for i, code in enumerate(codes):
if code.shape[0] < code.shape[1]:
codes[i] = code.t()
# store lengths
lens = torch.tensor([code.shape[0] for code in codes], device=device, dtype=torch.int32)
# pad and concat
codes = pad_sequence(codes, batch_first=True)
# re-transpose if needed
if codes.shape[1] > codes.shape[2]:
codes = rearrange(codes, "b t q -> b q t")
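# e.g. code tensors of shapes (500, 8) and (300, 8) pad to (2, 500, 8), then rearrange to (2, 8, 500)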
# upcast so it won't whine
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
codes = codes.to(torch.int32)
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
# load the model
model = _load_model(device)
# move to device
codes = codes.to( device=device )
# NeMo uses a different pathway
if model.backend == "nemo":
wav, lens = model.decode(tokens=codes, tokens_len=lens)
return [ wav[:l].unsqueeze(0) for wav, l in zip(wav, lens) ], cfg.sample_rate
# to-do: implement for encodec and vocos
raise Exception(f"Batch decoding unsupported for backend {cfg.audio_backend}")
# huh
def decode_to_wave(resps: Tensor, device="cuda"):
@@ -271,125 +279,96 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
# an experimental way to include "trained" embeddings from the audio backend itself
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
# each audio backend does its "embeddings" a different way that isn't just an embedding weight matrix
#
# this is overkill and I don't feel like this benefits anything, but it was an idea I had
# this only really works if the embedding dims match; otherwise either a Linear to rescale would be needed, or the embeddings semi-erroneously padded with 0s
@torch.inference_mode()
def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
model = _load_model(device)
codes = codes.to(device=device, dtype=torch.int32)
# yucky kludge
if sums:
if codes.dim() == 1:
codes = rearrange(codes, "t -> t 1")
if cfg.audio_backend == "dac":
x = []
for i in range(quant_level+1):
emb = model.quantizer.quantizers[i]
code = rearrange(codes[:, i], "t -> 1 t")
xi = emb.decode_code(code)
xi = emb.out_proj(xi)
x.append( xi[0].t() )
return sum(x).detach()
raise Exception(f'Currently only DAC is supported')
if codes.dim() == 2:
codes = codes[:, quant_level]
codes = rearrange(codes, "t -> 1 t")
# dac conveniently has its dim = 1024
if cfg.audio_backend == "dac":
emb = model.quantizer.quantizers[quant_level]
x = emb.decode_code(codes)
x = emb.out_proj(x)
x = x[0].t().detach()
return x
"""
# vocos inconveniently has its dim = 128
elif cfg.audio_backend == "vocos":
x = model.codes_to_features(codes)
# encodec inconveniently has its dim = 300
elif cfg.audio_backend == "encodec":
...
"""
raise Exception(f'Currently only DAC is supported')
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
# NeMo uses a different pathway
if cfg.audio_backend == "nemo":
model = _load_nemo_model( device )
# reshape (channel, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, model.sample_rate, 1)
# expand if 1D
if wav.dim() < 2:
wav = wav.unsqueeze(0)
# reshape (channels, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
wav = wav.to(device)[0, :, :]
l = torch.tensor([wav[0].shape[0]]).to(device)
codes, _ = model.encode(audio=wav, audio_len=l)
# ( batch, level, frame )
return codes[0]
# cringe assert
assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()'
# DAC uses a different pathway
if cfg.audio_backend == "dac":
model = _load_dac_model( device )
signal = AudioSignal(wav, sample_rate=sr)
artifact = model.compress(signal, win_duration=window_duration, verbose=False) # , n_quantizers=levels)
#artifact = model.compress(signal)
artifact = model.compress(signal, win_duration=window_duration, verbose=False)
return artifact.codes if not return_metadata else artifact
# AudioDec uses a different pathway
if cfg.audio_backend == "audiodec":
model = _load_audiodec_model(device)
# reshape (channel, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, model.sample_rate, 1)
wav = wav.to(device)
# wav = rearrange(wav, "t c -> t 1 c").to(device)
encoded = model.tx_encoder.encode(wav)
quantized = model.tx_encoder.quantize(encoded)
return quantized
# vocos does not encode wavs to encodecs, so just use normal encodec
model = _load_encodec_model(device)
# reshape (channel, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate or wav.shape[1] != model.channels:
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
# resample if necessary
if sr != cfg.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, cfg.sample_rate, 1)
wav = wav.to(device)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
encoded_frames = model.encode(wav)
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
# NeMo uses a different pathway
if cfg.audio_backend == "nemo":
model = _load_nemo_model( device )
return qnt
wav = wav.to(device)[:, 0, :]
l = torch.tensor([w.shape[0] for w in wav]).to(device)
codes, lens = model.encode(audio=wav, audio_len=l)
# to-do: unpad
return codes
# vocos does not encode wavs to encodecs, so just use normal encodec
if cfg.audio_backend in ["encodec", "vocos"]:
model = _load_encodec_model(device)
codes = model.encode(wav)
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
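# (encodec's encode() yields a list of (codes, scale) frames, so the concat above stitches them back together along the time axis)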
return codes
@torch.inference_mode()
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda" ):
# expand as list
if not isinstance(sr, list):
sr = [sr] * len(wavs)
# resample if necessary
for i, wav in enumerate(wavs):
if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1:
wavs[i] = convert_audio(wavs[i], sr[i], cfg.sample_rate, 1)
# (frames) => (channel, frames)
if wavs[i].dim() < 2:
wavs[i] = wavs[i].unsqueeze(0)
# transpose to (frames, channels), as pad_sequence pads along the first dim
if wavs[i].shape[0] < wavs[i].shape[1]:
wavs[i] = wavs[i].t()
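# e.g. mono audio loaded as (1, 48_000) becomes (48_000, 1), letting pad_sequence pad along the time dim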
# store lengths
lens = torch.tensor([wav.shape[0] for wav in wavs], device=device, dtype=torch.int32)
# pad and concat (transpose because pad_sequence requires it this way)
wav = pad_sequence(wavs, batch_first=True)
# untranspose
wav = rearrange(wav, "b t c -> b c t")
#
wav = wav.to(device)
# NeMo uses a different pathway
if cfg.audio_backend == "nemo":
model = _load_nemo_model( device )
wav = wav.to(device)[:, 0, :]
codes, code_lens = model.encode(audio=wav, audio_len=lens)
return [ code[:, :l] for code, l in zip( codes, code_lens ) ]
# can't be assed to implement
if cfg.audio_backend == "dac":
raise Exception(f"Batch encoding unsupported for backend {cfg.audio_backend}")
# naively encode
if cfg.audio_backend in ["encodec", "vocos"]:
model = _load_encodec_model(device)
codes = model.encode(wav)
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ]
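# note: the slice above converts raw sample counts into frame counts; with a hypothetical
# cfg.sample_rate of 24_000 and cfg.dataset.frames_per_second of 75, a 48_000-sample wav
# trims to 48_000 * 75 // 24_000 = 150 frames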
def encode_from_files(paths, device="cuda"):
tuples = [ torchaudio.load(str(path)) for path in paths ]


@@ -23,7 +23,7 @@ import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding, get_silence
from ..emb.qnt import trim, get_silence
from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
from .lora import enable_lora


@@ -32,7 +32,6 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .arch import *
from ..utils import wrapper as ml, clamp
from ..samplers import *
from ..emb.qnt import encode_as_embedding
# yuck, kind of needed
from ..data import get_task_symmap
@@ -115,21 +114,11 @@ def _join(x: tuple[Tensor], sep: Tensor):
ret = torch.cat((ret, sep[None], x[i]), dim=0)
return ret
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
"""
Args:
x_list: [(t d)]
Returns:
x: (? ? ?)
m: (? ? ?), same as x
"""
def list_to_tensor(x_list: list[Tensor]):
l = list(map(len, x_list))
x = rearrange(pad_sequence(x_list), pattern)
x = pad_sequence(x_list, batch_first=True)
m = _create_mask(l, x_list[0].device)
"""
m = m.t().unsqueeze(-1) # (t b 1)
m = rearrange(m, pattern)
"""
m = m.to(x).int()
return x, m