diff --git a/test.wav b/test.wav deleted file mode 100644 index e057d6c..0000000 Binary files a/test.wav and /dev/null differ diff --git a/vall_e/config.py b/vall_e/config.py index 9f59d14..8b9b087 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -245,7 +245,8 @@ class ModelExperimentalSettings: # a model trained not summing audio embeddings *can* have this enabled without any apparent issues # a model trained to sum *cannot* have this disabled without any apparent issues, or at least the ar+nar-retnet-8 can't. # in theory a model that is trained to sum embeddings can peform better due to "seeing" previous levles (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so - audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings + + audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings, currently not used kv_heads: int = 0 # MHA or GQA (for supported backends) rvq_levels_p: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 454865e..09c6742 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -18,6 +18,8 @@ from einops import rearrange from torch import Tensor from tqdm import tqdm +from torch.nn.utils.rnn import pad_sequence + try: from .codecs.encodec import * except Exception as e: @@ -67,10 +69,8 @@ def _load_encodec_model(device="cuda", levels=0): # extra metadata model.bandwidth_id = bandwidth_id - model.sample_rate = cfg.sample_rate model.normalize = cfg.inference.normalize model.backend = "encodec" - model.device = device return model @@ -96,9 +96,7 @@ def _load_vocos_model(device="cuda", levels=0): # extra metadata model.bandwidth_id = torch.tensor([bandwidth_id], device=device) - model.sample_rate = cfg.sample_rate model.backend = "vocos" - model.device = device return model @@ -119,23 +117,6 @@ def _load_dac_model(device="cuda"): model.backend = "dac" model.model_type = kwargs["model_type"] - #model.device = device - - return model - -@cache -def _load_audiodec_model(device="cuda", model_name=None): - if not model_name: - model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1" - sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name) - - model = AudioDec(tx_device=device, rx_device=device) - model.load_transmitter(encoder_checkpoint) - model.load_receiver(encoder_checkpoint, decoder_checkpoint) - - model.backend = "audiodec" - model.sample_rate = sample_rate - model.device = device return model @@ -147,8 +128,6 @@ def _load_nemo_model(device="cuda", model_name=None): model = AudioCodecModel.from_pretrained(model_name).to(device).eval() model.backend = "nemo" - model.sample_rate = 44_100 - #model.device = device return model @@ -173,41 +152,39 @@ def unload_model(): _load_model.cache_clear() _load_encodec_model.cache_clear() # because vocos can only decode +# to-do: clean up this mess @torch.inference_mode() def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): # upcast so it won't whine - if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8: + if codes.dtype in [torch.int8, torch.int16, torch.uint8]: codes = codes.to(torch.int32) # expand if we're given a raw 1-RVQ stream if codes.dim() == 1: codes = rearrange(codes, "t -> 1 1 t") - # expand to a batch size of one if not passed as a batch - # vocos does not do batch decoding, but encodec does, but we don't end up using this anyways *I guess* - # to-do, make this logical - elif codes.dim() == 2: - codes = rearrange(codes, "t q -> 1 q t") + # expand to a batch size of one if not passed as a batch + elif codes.dim() == 2: + # if (t, q), transpose to (q, t) instead + if codes.shape[0] > codes.shape[1]: + codes = codes.t() + codes = codes.unsqueeze(0) + + # life is easier if we assume we're using a batch assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}' # load the model model = _load_model(device) + # move to device + codes = codes.to( device=device ) # NeMo uses a different pathway if model.backend == "nemo": - # ugh - codes = rearrange( codes, "b q t -> b t q") - codes = codes.to( device=device ) - l = torch.tensor([codes.shape[-1]], device=device, dtype=torch.int32) + l = torch.tensor([c.shape[-1] for c in codes], device=device, dtype=torch.int32) wav, _ = model.decode(tokens=codes, tokens_len=l) - return wav, model.sample_rate - - # AudioDec uses a different pathway - if model.backend == "audiodec": - codes = codes.to( device=device )[0] - zq = model.rx_encoder.lookup( codes ) - wav = model.decoder.decode(zq).squeeze(1) - return wav, model.sample_rate + return wav, cfg.sample_rate + + assert codes.shape[0] == 1, f'Batch decoding is unsupported for backend: {model.backend}' # DAC uses a different pathway if model.backend == "dac": @@ -218,7 +195,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): original_length=0, input_db=-12, channels=1, - sample_rate=model.sample_rate, + sample_rate=cfg.sample_rate, padding=True, dac_version='1.0.0', ) @@ -242,21 +219,52 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): # to-do: inject the sample rate encoded at, because we can actually decouple return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate - kwargs = {} + # cleaner to separate out from EnCodec's pathway if model.backend == "vocos": x = model.codes_to_features(codes[0]) - kwargs['bandwidth_id'] = model.bandwidth_id - else: - # encodec will decode as a batch - x = [(codes.to(device), None)] + wav = model.decode(x, bandwidth_id=model.bandwidth_id) + return wav, cfg.sample_rate - wav = model.decode(x, **kwargs) - - # encodec will decode as a batch if model.backend == "encodec": - wav = wav[0] + x = [(codes.to(device), None)] + wav = model.decode(x) + return wav, cfg.sample_rate - return wav, model.sample_rate +@torch.inference_mode() +def decode_batch(codes: list[Tensor], device="cuda"): + # transpose if needed + for i, code in enumerate(codes): + if code.shape[0] < code.shape[1]: + codes[i] = code.t() + + # store lengths + lens = torch.tensor([code.shape[0] for code in codes], device=device, dtype=torch.int32) + + # pad and concat + codes = pad_sequence(codes, batch_first=True) + + # re-transpose if needed + if codes.shape[1] > codes.shape[2]: + codes = rearrange(codes, "b t q -> b q t") + + # upcast so it won't whine + if codes.dtype in [torch.int8, torch.int16, torch.uint8]: + codes = codes.to(torch.int32) + + assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}' + + # load the model + model = _load_model(device) + # move to device + codes = codes.to( device=device ) + + # NeMo uses a different pathway + if model.backend == "nemo": + wav, lens = model.decode(tokens=codes, tokens_len=lens) + return [ wav[:l].unsqueeze(0) for wav, l in zip(wav, lens) ], cfg.sample_rate + + # to-do: implement for encodec and vocos + raise Exception(f"Batch decoding unsupported for backend {cfg.audio_backend}") # huh def decode_to_wave(resps: Tensor, device="cuda"): @@ -271,125 +279,96 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"): def _replace_file_extension(path, suffix): return (path.parent / path.name.split(".")[0]).with_suffix(suffix) -# an experimental way to include "trained" embeddings from the audio backend itself -# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime -# each audio backend does their "embeddings" a different way that isn't just a embedding weights -# -# this is overkill and I don't feel like this benefits anything, but it was an idea I had -# this only really works if the embedding dims match, and either a Linear to rescale would be needed or semi-erroneously just padding with 0s -@torch.inference_mode() -def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"): - model = _load_model(device) - - codes = codes.to(device=device, dtype=torch.int32) - - # yucky kludge - if sums: - if codes.dim() == 1: - codes = rearrange(codes, "t -> t 1") - - if cfg.audio_backend == "dac": - x = [] - for i in range(quant_level+1): - emb = model.quantizer.quantizers[i] - code = rearrange(codes[:, quant_level], "t -> 1 t") - - xi = emb.decode_code(code) - xi = emb.out_proj(xi) - x.append( xi[0].t() ) - - return sum(x).detach() - - raise Exception(f'Currently only DAC is supported') - - - if codes.dim() == 2: - codes = codes[:, quant_level] - - codes = rearrange(codes, "t -> 1 t") - - # dac conveniently has its dim = 1024 - if cfg.audio_backend == "dac": - emb = model.quantizer.quantizers[quant_level] - - x = emb.decode_code(codes) - x = emb.out_proj(x) - x = x[0].t().detach() - - return x - - """ - # vocos inconveniently has its dim = 128 - elif cfg.audio_backend == "vocos": - x = model.codes_to_features(codes) - # encodec inconveniently has its dim = 300 - elif cfg.audio_backend == "encodec": - ... - """ - - raise Exception(f'Currently only DAC is supported') - @torch.inference_mode() def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None): - # NeMo uses a different pathway - if cfg.audio_backend == "nemo": - model = _load_nemo_model( device ) - # reshape (channel, samples) => (batch, channel, samples) - if wav.dim() < 3: - wav = wav.unsqueeze(0) - # skip unnecessary resample - if sr != model.sample_rate or wav.shape[1] != 1: - wav = convert_audio(wav, sr, model.sample_rate, 1) + # expand if 1D + if wav.dim() < 2: + wav = wav.unsqueeze(0) + # reshape (channels, samples) => (batch, channel, samples) + if wav.dim() < 3: + wav = wav.unsqueeze(0) - wav = wav.to(device)[0, :, :] - l = torch.tensor([wav[0].shape[0]]).to(device) - - codes, _ = model.encode(audio=wav, audio_len=l) - # ( batch, level, frame ) - return codes[0] + # cringe assert + assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()' # DAC uses a different pathway if cfg.audio_backend == "dac": model = _load_dac_model( device ) signal = AudioSignal(wav, sample_rate=sr) - artifact = model.compress(signal, win_duration=window_duration, verbose=False) # , n_quantizers=levels) - #artifact = model.compress(signal) + artifact = model.compress(signal, win_duration=window_duration, verbose=False) return artifact.codes if not return_metadata else artifact - - # AudioDec uses a different pathway - if cfg.audio_backend == "audiodec": - model = _load_audiodec_model(device) - # reshape (channel, samples) => (batch, channel, samples) - if wav.dim() < 3: - wav = wav.unsqueeze(0) - # skip unnecessary resample - if sr != model.sample_rate or wav.shape[1] != 1: - wav = convert_audio(wav, sr, model.sample_rate, 1) - wav = wav.to(device) - - # wav = rearrange(wav, "t c -> t 1 c").to(device) - encoded = model.tx_encoder.encode(wav) - quantized = model.tx_encoder.quantize(encoded) - return quantized - - # vocos does not encode wavs to encodecs, so just use normal encodec - model = _load_encodec_model(device) - # reshape (channel, samples) => (batch, channel, samples) - if wav.dim() < 3: - wav = wav.unsqueeze(0) - # skip unnecessary resample - if sr != model.sample_rate or wav.shape[1] != model.channels: - wav = convert_audio(wav, sr, model.sample_rate, model.channels) - + + # resample if necessary + if sr != cfg.sample_rate or wav.shape[1] != 1: + wav = convert_audio(wav, sr, cfg.sample_rate, 1) + wav = wav.to(device) - with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): - encoded_frames = model.encode(wav) - qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t) + # NeMo uses a different pathway + if cfg.audio_backend == "nemo": + model = _load_nemo_model( device ) - return qnt + wav = wav.to(device)[:, 0, :] + l = torch.tensor([w.shape[0] for w in wav]).to(device) + codes, lens = model.encode(audio=wav, audio_len=l) + # to-do: unpad + return codes + # vocos does not encode wavs to encodecs, so just use normal encodec + if cfg.audio_backend in ["encodec", "vocos"]: + model = _load_encodec_model(device) + codes = model.encode(wav) + codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t) + return codes + +@torch.inference_mode() +def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda" ): + # expand as list + if not isinstance(sr, list): + sr = [sr] * len(wavs) + + # resample if necessary + for i, wav in enumerate(wavs): + if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1: + wavs[i] = convert_audio(wavs[i], sr[i], cfg.sample_rate, 1) + + # (frames) => (channel, frames) + if wavs[i].dim() < 2: + wavs[i] = wavs[i].unsqueeze(0) + + # transpose is required + if wavs[i].shape[0] < wavs[i].shape[1]: + wavs[i] = wavs[i].t() + + # store lengths + lens = torch.tensor([wav.shape[0] for wav in wavs], device=device, dtype=torch.int32) + + # pad and concat (transpose because pad_sequence requires it this way) + wav = pad_sequence(wavs, batch_first=True) + # untranspose + wav = rearrange(wav, "b t c -> b c t") + # + wav = wav.to(device) + + # NeMo uses a different pathway + if cfg.audio_backend == "nemo": + model = _load_nemo_model( device ) + wav = wav.to(device)[:, 0, :] + codes, code_lens = model.encode(audio=wav, audio_len=lens) + return [ code[:, :l] for code, l in zip( codes, code_lens ) ] + + # can't be assed to implement + if cfg.audio_backend == "dac": + raise Exception(f"Batch encoding unsupported for backend {cfg.audio_backend}") + + # naively encode + if cfg.audio_backend in ["encodec", "vocos"]: + model = _load_encodec_model(device) + codes = model.encode(wav) + codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t) + + return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ] def encode_from_files(paths, device="cuda"): tuples = [ torchaudio.load(str(path)) for path in paths ] diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 78326a8..4c8d9ec 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -23,7 +23,7 @@ import logging _logger = logging.getLogger(__name__) -from ..emb.qnt import trim, encode_as_embedding, get_silence +from ..emb.qnt import trim, get_silence from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs from .lora import enable_lora diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 06bb89d..d224871 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -32,7 +32,6 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult from .arch import * from ..utils import wrapper as ml, clamp from ..samplers import * -from ..emb.qnt import encode_as_embedding # yuck, kind of needed from ..data import get_task_symmap @@ -115,21 +114,11 @@ def _join(x: tuple[Tensor], sep: Tensor): ret = torch.cat((ret, sep[None], x[i]), dim=0) return ret -def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"): - """ - Args: - x_list: [(t d)] - Returns: - x: (? ? ?) - m: (? ? ?), same as x - """ +def list_to_tensor(x_list: list[Tensor]): l = list(map(len, x_list)) - x = rearrange(pad_sequence(x_list), pattern) + x = pad_sequence(x_list, batch_first=True) m = _create_mask(l, x_list[0].device) - """ - m = m.t().unsqueeze(-1) # (t b 1) - m = rearrange(m, pattern) - """ + m = m.to(x).int() return x, m