Add (commented-out) support for facebookresearch/AudioDec; the results really didn't wow me, so it stays commented out until I figure out why my output audio is super crusty with AudioDec
This commit is contained in:
parent
db62e55a38
commit
1ecf2793f4
|
@ -776,6 +776,9 @@ class Config(BaseConfig):
|
||||||
self.models = [ Model(**model) for model in self.models ]
|
self.models = [ Model(**model) for model in self.models ]
|
||||||
self.loras = [ LoRA(**lora) for lora in self.loras ]
|
self.loras = [ LoRA(**lora) for lora in self.loras ]
|
||||||
|
|
||||||
|
if not self.models:
|
||||||
|
self.models = [ Model() ]
|
||||||
|
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
if not isinstance( model.experimental, dict ):
|
if not isinstance( model.experimental, dict ):
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -89,6 +89,18 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cfg.inference.use_dac = False
|
cfg.inference.use_dac = False
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
|
"""
|
||||||
|
# uses https://github.com/facebookresearch/AudioDec/
|
||||||
|
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
|
||||||
|
# I was not happy with it in testing; it sounded rather mediocre.
|
||||||
|
try:
|
||||||
|
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
|
||||||
|
except Exception as e:
|
||||||
|
cfg.inference.use_audiodec = False
|
||||||
|
print(str(e))
|
||||||
|
"""
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
||||||
assert cfg.sample_rate == 24_000
|
assert cfg.sample_rate == 24_000
|
||||||
|
@ -146,7 +158,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
||||||
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
||||||
if not cfg.variable_sample_rate:
|
if not cfg.variable_sample_rate:
|
||||||
# yes there's a better way, something like f'{cfg.sample_rate//1000}khz'
|
# yes there's a better way, something like f'{cfg.sample_rate//1000}khz'
|
||||||
if cfg.sample_rate == 44_000:
|
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100:
|
||||||
kwargs["model_type"] = "44khz"
|
kwargs["model_type"] = "44khz"
|
||||||
elif cfg.sample_rate == 16_000:
|
elif cfg.sample_rate == 16_000:
|
||||||
kwargs["model_type"] = "16khz"
|
kwargs["model_type"] = "16khz"
|
||||||
|
@ -157,10 +169,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
# extra metadata
|
# to revisit later, but experiments have shown that this is a bad idea
|
||||||
|
|
||||||
# since DAC moreso models against waveforms, we can actually use a smaller sample rate
|
|
||||||
# updating it here will affect the sample rate the waveform is resampled to on encoding
|
|
||||||
if cfg.variable_sample_rate:
|
if cfg.variable_sample_rate:
|
||||||
model.sample_rate = cfg.sample_rate
|
model.sample_rate = cfg.sample_rate
|
||||||
|
|
||||||
|
@ -169,8 +178,29 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
"""
|
||||||
|
@cache
|
||||||
|
def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
|
||||||
|
if not model_name:
|
||||||
|
model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
|
||||||
|
sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
|
||||||
|
|
||||||
|
model = AudioDec(tx_device=device , rx_device=device )
|
||||||
|
model.load_transmitter(encoder_checkpoint)
|
||||||
|
model.load_receiver(encoder_checkpoint, decoder_checkpoint)
|
||||||
|
|
||||||
|
model.backend = "audiodec"
|
||||||
|
model.sample_rate = sample_rate
|
||||||
|
|
||||||
|
return model
|
||||||
|
"""
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
|
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
|
||||||
|
"""
|
||||||
|
if backend == "audiodec":
|
||||||
|
return _load_audiodec_model(device, levels=levels)
|
||||||
|
"""
|
||||||
if backend == "dac":
|
if backend == "dac":
|
||||||
return _load_dac_model(device, levels=levels)
|
return _load_dac_model(device, levels=levels)
|
||||||
if backend == "vocos":
|
if backend == "vocos":
|
||||||
|
@ -202,6 +232,15 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
||||||
# load the model
|
# load the model
|
||||||
model = _load_model(device, levels=levels)
|
model = _load_model(device, levels=levels)
|
||||||
|
|
||||||
|
# AudioDec uses a different pathway
|
||||||
|
"""
|
||||||
|
if model.backend == "audiodec":
|
||||||
|
codes = codes.to( device=device )[0]
|
||||||
|
zq = model.rx_encoder.lookup( codes )
|
||||||
|
wav = model.decoder.decode(zq).squeeze(1)
|
||||||
|
return wav, model.sample_rate
|
||||||
|
"""
|
||||||
|
|
||||||
# DAC uses a different pathway
|
# DAC uses a different pathway
|
||||||
if model.backend == "dac":
|
if model.backend == "dac":
|
||||||
dummy = False
|
dummy = False
|
||||||
|
@ -233,7 +272,6 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
||||||
# to-do: inject the sample rate encoded at, because we can actually decouple
|
# to-do: inject the sample rate encoded at, because we can actually decouple
|
||||||
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
|
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
|
||||||
|
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if model.backend == "vocos":
|
if model.backend == "vocos":
|
||||||
x = model.codes_to_features(codes[0])
|
x = model.codes_to_features(codes[0])
|
||||||
|
@ -266,6 +304,9 @@ def _replace_file_extension(path, suffix):
|
||||||
# an experimental way to include "trained" embeddings from the audio backend itself
|
# an experimental way to include "trained" embeddings from the audio backend itself
|
||||||
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
|
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
|
||||||
# each audio backend does its "embeddings" a different way that isn't just embedding weights
|
# each audio backend does its "embeddings" a different way that isn't just embedding weights
|
||||||
|
#
|
||||||
|
# this is overkill and I don't feel like this benefits anything, but it was an idea I had
|
||||||
|
# this only really works if the embedding dims match; otherwise either a Linear to rescale would be needed, or (semi-erroneously) just padding with 0s
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
|
def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
|
||||||
model = _load_model(device)
|
model = _load_model(device)
|
||||||
|
@ -320,6 +361,8 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device=
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
|
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
|
||||||
|
|
||||||
|
# DAC uses a different pathway
|
||||||
if cfg.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
model = _load_dac_model(device, levels=levels )
|
model = _load_dac_model(device, levels=levels )
|
||||||
signal = AudioSignal(wav, sample_rate=sr)
|
signal = AudioSignal(wav, sample_rate=sr)
|
||||||
|
@ -330,7 +373,20 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
|
||||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
||||||
|
|
||||||
|
# AudioDec uses a different pathway
|
||||||
|
"""
|
||||||
|
if cfg.audio_backend == "audiodec":
|
||||||
|
model = _load_audiodec_model(device, levels=levels )
|
||||||
|
wav = wav.unsqueeze(0)
|
||||||
|
wav = convert_audio(wav, sr, model.sample_rate, 1)
|
||||||
|
wav = wav.to(device)
|
||||||
|
|
||||||
|
# wav = rearrange(wav, "t c -> t 1 c").to(device)
|
||||||
|
encoded = model.tx_encoder.encode(wav)
|
||||||
|
quantized = model.tx_encoder.quantize(encoded)
|
||||||
|
return quantized
|
||||||
return artifact.codes if not return_metadata else artifact
|
return artifact.codes if not return_metadata else artifact
|
||||||
|
"""
|
||||||
|
|
||||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||||
model = _load_encodec_model(device, levels=levels)
|
model = _load_encodec_model(device, levels=levels)
|
||||||
|
@ -431,3 +487,18 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
||||||
|
|
||||||
combined = sum(decoded) / len(decoded)
|
combined = sum(decoded) / len(decoded)
|
||||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||||
|
|
||||||
|
"""
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from vall_e.emb.qnt import encode, decode, cfg
|
||||||
|
|
||||||
|
cfg.sample_rate = 48_000
|
||||||
|
cfg.audio_backend = "audiodec"
|
||||||
|
|
||||||
|
wav, sr = torchaudio.load("in.wav")
|
||||||
|
codes = encode( wav, sr ).t()
|
||||||
|
print( "ENCODED:", codes.shape, codes )
|
||||||
|
wav, sr = decode( codes )
|
||||||
|
print( "DECODED:", wav.shape, wav )
|
||||||
|
torchaudio.save("out.wav", wav.cpu(), sr)
|
||||||
|
"""
|
Loading…
Reference in New Issue
Block a user