(commented-out) support for facebookresearch/AudioDec, but support really didn't wow me (so I commented it out until I figure out why my output audio is super crusty with AudioDec)

This commit is contained in:
mrq 2024-07-04 15:40:51 -05:00
parent db62e55a38
commit 1ecf2793f4
2 changed files with 81 additions and 7 deletions

View File

@ -776,6 +776,9 @@ class Config(BaseConfig):
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
if not self.models:
self.models = [ Model() ]
for model in self.models:
if not isinstance( model.experimental, dict ):
continue

View File

@ -89,6 +89,18 @@ try:
except Exception as e:
cfg.inference.use_dac = False
print(str(e))
"""
# uses https://github.com/facebookresearch/AudioDec/
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
# I was not happy with testing, it sounded rather mediocre.
try:
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
except Exception as e:
cfg.inference.use_audiodec = False
print(str(e))
"""
@cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
assert cfg.sample_rate == 24_000
@ -146,7 +158,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
if not cfg.variable_sample_rate:
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
if cfg.sample_rate == 44_000:
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100:
kwargs["model_type"] = "44khz"
elif cfg.sample_rate == 16_000:
kwargs["model_type"] = "16khz"
@ -157,10 +169,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
model = model.to(device)
model = model.eval()
# extra metadata
# since DAC moreso models against waveforms, we can actually use a smaller sample rate
# updating it here will affect the sample rate the waveform is resampled to on encoding
# to revisit later, but experiments shown that this is a bad idea
if cfg.variable_sample_rate:
model.sample_rate = cfg.sample_rate
@ -169,8 +178,29 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
return model
"""
@cache
def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
if not model_name:
model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
model = AudioDec(tx_device=device , rx_device=device )
model.load_transmitter(encoder_checkpoint)
model.load_receiver(encoder_checkpoint, decoder_checkpoint)
model.backend = "audiodec"
model.sample_rate = sample_rate
return model
"""
@cache
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
"""
if backend == "audiodec":
return _load_audiodec_model(device, levels=levels)
"""
if backend == "dac":
return _load_dac_model(device, levels=levels)
if backend == "vocos":
@ -202,6 +232,15 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
# load the model
model = _load_model(device, levels=levels)
# AudioDec uses a different pathway
"""
if model.backend == "audiodec":
codes = codes.to( device=device )[0]
zq = model.rx_encoder.lookup( codes )
wav = model.decoder.decode(zq).squeeze(1)
return wav, model.sample_rate
"""
# DAC uses a different pathway
if model.backend == "dac":
dummy = False
@ -233,7 +272,6 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
# to-do: inject the sample rate encoded at, because we can actually decouple
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
kwargs = {}
if model.backend == "vocos":
x = model.codes_to_features(codes[0])
@ -266,6 +304,9 @@ def _replace_file_extension(path, suffix):
# an experimental way to include "trained" embeddings from the audio backend itself
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
# each audio backend does their "embeddings" a different way that isn't just a embedding weights
#
# this is overkill and I don't feel like this benefits anything, but it was an idea I had
# this only really works if the embedding dims match, and either a Linear to rescale would be needed or semi-erroneously just padding with 0s
@torch.inference_mode()
def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
model = _load_model(device)
@ -320,6 +361,8 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device=
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
# DAC uses a different pathway
if cfg.audio_backend == "dac":
model = _load_dac_model(device, levels=levels )
signal = AudioSignal(wav, sample_rate=sr)
@ -330,7 +373,20 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
# AudioDec uses a different pathway
"""
if cfg.audio_backend == "audiodec":
model = _load_audiodec_model(device, levels=levels )
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, 1)
wav = wav.to(device)
# wav = rearrange(wav, "t c -> t 1 c").to(device)
encoded = model.tx_encoder.encode(wav)
quantized = model.tx_encoder.quantize(encoded)
return quantized
return artifact.codes if not return_metadata else artifact
"""
# vocos does not encode wavs to encodecs, so just use normal encodec
model = _load_encodec_model(device, levels=levels)
@ -430,4 +486,19 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
decoded[i] = decoded[i] * scale[i]
combined = sum(decoded) / len(decoded)
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
"""
if __name__ == "__main__":
from vall_e.emb.qnt import encode, decode, cfg
cfg.sample_rate = 48_000
cfg.audio_backend = "audiodec"
wav, sr = torchaudio.load("in.wav")
codes = encode( wav, sr ).t()
print( "ENCODED:", codes.shape, codes )
wav, sr = decode( codes )
print( "DECODED:", wav.shape, wav )
torchaudio.save("out.wav", wav.cpu(), sr)
"""