From 1ecf2793f4c8a046416fcc03a4f58310ac53f6b9 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 4 Jul 2024 15:40:51 -0500 Subject: [PATCH] (commented-out) support for facebookresearch/AudioDec, but support really didn't wow me (so I commented it out until I figure out why my output audio is super crusty with AudioDec) --- vall_e/config.py | 3 ++ vall_e/emb/qnt.py | 85 +++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 36899f3..e0582fa 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -776,6 +776,9 @@ class Config(BaseConfig): self.models = [ Model(**model) for model in self.models ] self.loras = [ LoRA(**lora) for lora in self.loras ] + if not self.models: + self.models = [ Model() ] + for model in self.models: if not isinstance( model.experimental, dict ): continue diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 4c9c534..ab61d63 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -89,6 +89,18 @@ try: except Exception as e: cfg.inference.use_dac = False print(str(e)) + +""" +# uses https://github.com/facebookresearch/AudioDec/ +# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip +# I was not happy with testing, it sounded rather mediocre. +try: + from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model +except Exception as e: + cfg.inference.use_audiodec = False + print(str(e)) +""" + @cache def _load_encodec_model(device="cuda", levels=cfg.model.max_levels): assert cfg.sample_rate == 24_000 @@ -146,7 +158,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels): kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest") if not cfg.variable_sample_rate: # yes there's a better way, something like f'{cfg.sample.rate//1000}hz' - if cfg.sample_rate == 44_000: + if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: kwargs["model_type"] = "44khz" elif cfg.sample_rate == 16_000: kwargs["model_type"] = "16khz" @@ -157,10 +169,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels): model = model.to(device) model = model.eval() - # extra metadata - - # since DAC moreso models against waveforms, we can actually use a smaller sample rate - # updating it here will affect the sample rate the waveform is resampled to on encoding + # to revisit later, but experiments shown that this is a bad idea if cfg.variable_sample_rate: model.sample_rate = cfg.sample_rate @@ -169,8 +178,29 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels): return model +""" +@cache +def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels): + if not model_name: + model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1" + sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name) + + model = AudioDec(tx_device=device , rx_device=device ) + model.load_transmitter(encoder_checkpoint) + model.load_receiver(encoder_checkpoint, decoder_checkpoint) + + model.backend = "audiodec" + model.sample_rate = sample_rate + + return model +""" + @cache def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels): + """ + if backend == "audiodec": + return _load_audiodec_model(device, levels=levels) + """ if backend == "dac": return _load_dac_model(device, levels=levels) if backend == "vocos": @@ -202,6 +232,15 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N # load the model model = _load_model(device, levels=levels) + # AudioDec uses a different pathway + """ + if model.backend == "audiodec": + codes = codes.to( device=device )[0] + zq = model.rx_encoder.lookup( codes ) + wav = model.decoder.decode(zq).squeeze(1) + return wav, model.sample_rate + """ + # DAC uses a different pathway if model.backend == "dac": dummy = False @@ -233,7 +272,6 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N # to-do: inject the sample rate encoded at, because we can actually decouple return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate - kwargs = {} if model.backend == "vocos": x = model.codes_to_features(codes[0]) @@ -266,6 +304,9 @@ def _replace_file_extension(path, suffix): # an experimental way to include "trained" embeddings from the audio backend itself # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime # each audio backend does their "embeddings" a different way that isn't just a embedding weights +# +# this is overkill and I don't feel like this benefits anything, but it was an idea I had +# this only really works if the embedding dims match, and either a Linear to rescale would be needed or semi-erroneously just padding with 0s @torch.inference_mode() def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"): model = _load_model(device) @@ -320,6 +361,8 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device= @torch.inference_mode() def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True): + + # DAC uses a different pathway if cfg.audio_backend == "dac": model = _load_dac_model(device, levels=levels ) signal = AudioSignal(wav, sample_rate=sr) @@ -330,7 +373,20 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) + # AudioDec uses a different pathway + """ + if cfg.audio_backend == "audiodec": + model = _load_audiodec_model(device, levels=levels ) + wav = wav.unsqueeze(0) + wav = convert_audio(wav, sr, model.sample_rate, 1) + wav = wav.to(device) + + # wav = rearrange(wav, "t c -> t 1 c").to(device) + encoded = model.tx_encoder.encode(wav) + quantized = model.tx_encoder.quantize(encoded) + return quantized return artifact.codes if not return_metadata else artifact + """ # vocos does not encode wavs to encodecs, so just use normal encodec model = _load_encodec_model(device, levels=levels) @@ -430,4 +486,19 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ): decoded[i] = decoded[i] * scale[i] combined = sum(decoded) / len(decoded) - return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t() \ No newline at end of file + return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t() + +""" +if __name__ == "__main__": + from vall_e.emb.qnt import encode, decode, cfg + + cfg.sample_rate = 48_000 + cfg.audio_backend = "audiodec" + + wav, sr = torchaudio.load("in.wav") + codes = encode( wav, sr ).t() + print( "ENCODED:", codes.shape, codes ) + wav, sr = decode( codes ) + print( "DECODED:", wav.shape, wav ) + torchaudio.save("out.wav", wav.cpu(), sr) +""" \ No newline at end of file