(commented-out) support for facebookresearch/AudioDec, but support really didn't wow me (so I commented it out until I figure out why my output audio is super crusty with AudioDec)
This commit is contained in:
parent
db62e55a38
commit
1ecf2793f4
|
@ -776,6 +776,9 @@ class Config(BaseConfig):
|
|||
self.models = [ Model(**model) for model in self.models ]
|
||||
self.loras = [ LoRA(**lora) for lora in self.loras ]
|
||||
|
||||
if not self.models:
|
||||
self.models = [ Model() ]
|
||||
|
||||
for model in self.models:
|
||||
if not isinstance( model.experimental, dict ):
|
||||
continue
|
||||
|
|
|
@ -89,6 +89,18 @@ try:
|
|||
except Exception as e:
|
||||
cfg.inference.use_dac = False
|
||||
print(str(e))
|
||||
|
||||
"""
|
||||
# uses https://github.com/facebookresearch/AudioDec/
|
||||
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
|
||||
# I was not happy with testing, it sounded rather mediocre.
|
||||
try:
|
||||
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
|
||||
except Exception as e:
|
||||
cfg.inference.use_audiodec = False
|
||||
print(str(e))
|
||||
"""
|
||||
|
||||
@cache
|
||||
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
||||
assert cfg.sample_rate == 24_000
|
||||
|
@ -146,7 +158,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
|||
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
||||
if not cfg.variable_sample_rate:
|
||||
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
||||
if cfg.sample_rate == 44_000:
|
||||
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100:
|
||||
kwargs["model_type"] = "44khz"
|
||||
elif cfg.sample_rate == 16_000:
|
||||
kwargs["model_type"] = "16khz"
|
||||
|
@ -157,10 +169,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
|||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
# extra metadata
|
||||
|
||||
# since DAC moreso models against waveforms, we can actually use a smaller sample rate
|
||||
# updating it here will affect the sample rate the waveform is resampled to on encoding
|
||||
# to revisit later, but experiments shown that this is a bad idea
|
||||
if cfg.variable_sample_rate:
|
||||
model.sample_rate = cfg.sample_rate
|
||||
|
||||
|
@ -169,8 +178,29 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
|||
|
||||
return model
|
||||
|
||||
"""
|
||||
@cache
|
||||
def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
|
||||
if not model_name:
|
||||
model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
|
||||
sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
|
||||
|
||||
model = AudioDec(tx_device=device , rx_device=device )
|
||||
model.load_transmitter(encoder_checkpoint)
|
||||
model.load_receiver(encoder_checkpoint, decoder_checkpoint)
|
||||
|
||||
model.backend = "audiodec"
|
||||
model.sample_rate = sample_rate
|
||||
|
||||
return model
|
||||
"""
|
||||
|
||||
@cache
|
||||
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
|
||||
"""
|
||||
if backend == "audiodec":
|
||||
return _load_audiodec_model(device, levels=levels)
|
||||
"""
|
||||
if backend == "dac":
|
||||
return _load_dac_model(device, levels=levels)
|
||||
if backend == "vocos":
|
||||
|
@ -202,6 +232,15 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
|||
# load the model
|
||||
model = _load_model(device, levels=levels)
|
||||
|
||||
# AudioDec uses a different pathway
|
||||
"""
|
||||
if model.backend == "audiodec":
|
||||
codes = codes.to( device=device )[0]
|
||||
zq = model.rx_encoder.lookup( codes )
|
||||
wav = model.decoder.decode(zq).squeeze(1)
|
||||
return wav, model.sample_rate
|
||||
"""
|
||||
|
||||
# DAC uses a different pathway
|
||||
if model.backend == "dac":
|
||||
dummy = False
|
||||
|
@ -233,7 +272,6 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
|||
# to-do: inject the sample rate encoded at, because we can actually decouple
|
||||
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
|
||||
|
||||
|
||||
kwargs = {}
|
||||
if model.backend == "vocos":
|
||||
x = model.codes_to_features(codes[0])
|
||||
|
@ -266,6 +304,9 @@ def _replace_file_extension(path, suffix):
|
|||
# an experimental way to include "trained" embeddings from the audio backend itself
|
||||
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
|
||||
# each audio backend does their "embeddings" a different way that isn't just a embedding weights
|
||||
#
|
||||
# this is overkill and I don't feel like this benefits anything, but it was an idea I had
|
||||
# this only really works if the embedding dims match, and either a Linear to rescale would be needed or semi-erroneously just padding with 0s
|
||||
@torch.inference_mode()
|
||||
def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
|
||||
model = _load_model(device)
|
||||
|
@ -320,6 +361,8 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device=
|
|||
|
||||
@torch.inference_mode()
|
||||
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
|
||||
|
||||
# DAC uses a different pathway
|
||||
if cfg.audio_backend == "dac":
|
||||
model = _load_dac_model(device, levels=levels )
|
||||
signal = AudioSignal(wav, sample_rate=sr)
|
||||
|
@ -330,7 +373,20 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
|
|||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
||||
|
||||
# AudioDec uses a different pathway
|
||||
"""
|
||||
if cfg.audio_backend == "audiodec":
|
||||
model = _load_audiodec_model(device, levels=levels )
|
||||
wav = wav.unsqueeze(0)
|
||||
wav = convert_audio(wav, sr, model.sample_rate, 1)
|
||||
wav = wav.to(device)
|
||||
|
||||
# wav = rearrange(wav, "t c -> t 1 c").to(device)
|
||||
encoded = model.tx_encoder.encode(wav)
|
||||
quantized = model.tx_encoder.quantize(encoded)
|
||||
return quantized
|
||||
return artifact.codes if not return_metadata else artifact
|
||||
"""
|
||||
|
||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||
model = _load_encodec_model(device, levels=levels)
|
||||
|
@ -430,4 +486,19 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
|||
decoded[i] = decoded[i] * scale[i]
|
||||
|
||||
combined = sum(decoded) / len(decoded)
|
||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
from vall_e.emb.qnt import encode, decode, cfg
|
||||
|
||||
cfg.sample_rate = 48_000
|
||||
cfg.audio_backend = "audiodec"
|
||||
|
||||
wav, sr = torchaudio.load("in.wav")
|
||||
codes = encode( wav, sr ).t()
|
||||
print( "ENCODED:", codes.shape, codes )
|
||||
wav, sr = decode( codes )
|
||||
print( "DECODED:", wav.shape, wav )
|
||||
torchaudio.save("out.wav", wav.cpu(), sr)
|
||||
"""
|
Loading…
Reference in New Issue
Block a user