finally swallowing the Descript-Audio-Codec pill (I guess I'm going to have to regenerate my entire dataset)

This commit is contained in:
mrq 2024-04-17 20:39:35 -05:00
parent b0bd88833c
commit 5ff2b4aab5
3 changed files with 190 additions and 40 deletions

View File

@@ -484,7 +484,12 @@ class Inference:
amp: bool = False
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
audio_backend: str = "vocos"
# legacy / backwards compat
use_vocos: bool = True
use_encodec: bool = True
use_dac: bool = True
recurrent_chunk_size: int = 0
recurrent_forward: bool = False
@@ -576,22 +581,30 @@ class Config(_Config):
self.dataset.use_hdf5 = False
def format( self ):
#if not isinstance(self.dataset, type):
self.dataset = Dataset(**self.dataset)
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
#if not isinstance(self.model, type):
if self.models is not None:
self.model = Model(**next(iter(self.models)))
else:
self.model = Model(**self.model)
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)
self.inference = Inference(**self.inference)
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
#if not isinstance(self.hyperparameters, type):
self.hyperparameters = Hyperparameters(**self.hyperparameters)
#if not isinstance(self.evaluation, type):
self.evaluation = Evaluation(**self.evaluation)
#if not isinstance(self.trainer, type):
self.trainer = Trainer(**self.trainer)
if not isinstance(self.trainer.deepspeed, type):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
#if not isinstance(self.inference, type):
self.inference = Inference(**self.inference)
#if not isinstance(self.bitsandbytes, type):
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
cfg = Config.from_cli()
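For context, format() hydrates the plain dicts parsed from the config file into the typed sections above, so downstream code can read the new backend field directly. A minimal usage sketch (hypothetical, assuming from_cli() runs format() as part of loading):

```python
# hypothetical sketch: after loading, nested sections are typed objects
cfg = Config.from_cli()
assert isinstance(cfg.inference, Inference)
print(cfg.inference.audio_backend)  # "vocos" by default; "dac" opts into Descript-Audio-Codec
```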

View File

@@ -9,20 +9,89 @@ from functools import cache
from pathlib import Path
from encodec import EncodecModel
from encodec.utils import convert_audio
from einops import rearrange
from torch import Tensor
from tqdm import tqdm, trange
try:
from encodec import EncodecModel
from encodec.utils import convert_audio
except Exception as e:
cfg.inference.use_encodec = False
try:
from vocos import Vocos
except Exception as e:
cfg.inference.use_vocos = False
try:
from dac import DACFile
from audiotools import AudioSignal
from dac.utils import load_model as __load_dac_model
"""
Patch decompress to skip the metadata-driven post-processing (namely the waveform trimming).
So far it seems the raw waveform can be returned as-is without any post-processing.
A smarter implementation would reuse the values from the input prompt.
"""
from dac.model.base import CodecMixin
@torch.no_grad()
def CodecMixin_decompress(
self,
obj: Union[str, Path, DACFile],
verbose: bool = False,
) -> AudioSignal:
self.eval()
if isinstance(obj, (str, Path)):
obj = DACFile.load(obj)
original_padding = self.padding
self.padding = obj.padding
range_fn = trange if verbose else range
codes = obj.codes
original_device = codes.device
chunk_length = obj.chunk_length
recons = []
for i in range_fn(0, codes.shape[-1], chunk_length):
c = codes[..., i : i + chunk_length].to(self.device)
z = self.quantizer.from_codes(c)[0]
r = self.decode(z)
recons.append(r.to(original_device))
recons = torch.cat(recons, dim=-1)
recons = AudioSignal(recons, self.sample_rate)
# to-do: the original post-processing implementation, kept below for reference
"""
resample_fn = recons.resample
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
"""
self.padding = original_padding
return recons
CodecMixin.decompress = CodecMixin_decompress
except Exception as e:
cfg.inference.use_dac = False
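Because each backend import sits in its own try/except, the legacy use_* booleans now double as runtime availability flags. A hedged sketch (hypothetical helper, not part of this commit) of listing whichever backends actually imported, in the same priority order _load_model dispatches on:

```python
# hypothetical helper: report which codec backends imported successfully
def available_backends() -> list[str]:
    flags = [
        ("dac", cfg.inference.use_dac),
        ("vocos", cfg.inference.use_vocos),
        ("encodec", cfg.inference.use_encodec),
    ]
    return [name for name, ok in flags if ok]
```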
@cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
# Instantiate a pretrained EnCodec model
assert cfg.sample_rate == 24_000
# too lazy to un-if ladder this shit
@@ -34,8 +103,14 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
elif levels == 8:
bandwidth_id = 6.0
model = EncodecModel.encodec_model_24khz().to(device)
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(bandwidth_id)
model = model.to(device)
model = model.eval()
# extra metadata
model.bandwidth_id = bandwidth_id
model.sample_rate = cfg.sample_rate
model.normalize = cfg.inference.normalize
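For reference, the if-ladder above is selecting from EnCodec 24kHz's fixed bandwidth settings; the full bandwidth-to-codebook correspondence (a property of the pretrained model, not code from this commit) is:

```python
# EnCodec 24kHz: target bandwidth in kbps -> number of RVQ codebooks
ENCODEC_BANDWIDTH_TO_LEVELS = {1.5: 2, 3.0: 4, 6.0: 8, 12.0: 16, 24.0: 32}
```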
@@ -49,6 +124,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
model = model.to(device)
model = model.eval()
# too lazy to un-if ladder this shit
bandwidth_id = 2
@@ -59,6 +135,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
elif levels == 8:
bandwidth_id = 2
# extra metadata
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
model.sample_rate = cfg.sample_rate
model.backend = "vocos"
@@ -66,25 +143,48 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
return model
@cache
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels):
if vocos:
model = _load_vocos_model(device, levels=levels)
def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
kwargs = dict(model_type="44khz", model_bitrate="8kbps", tag="latest")
# yes there's a better way, something like f'{cfg.sample_rate//1000}khz'
if cfg.sample_rate == 44_100: # DAC's 44khz model expects 44.1kHz audio
kwargs["model_type"] = "44khz"
elif cfg.sample_rate == 24_000:
kwargs["model_type"] = "24khz"
elif cfg.sample_rate == 16_000:
kwargs["model_type"] = "16khz"
else:
model = _load_encodec_model(device, levels=levels)
raise Exception(f'unsupported sample rate: {cfg.sample_rate}')
model = __load_dac_model(**kwargs)
model = model.to(device)
model = model.eval()
# extra metadata
model.sample_rate = cfg.sample_rate
model.backend = "dac"
return model
@cache
def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
if backend == "dac":
return _load_dac_model(device, levels=levels)
if backend == "vocos":
return _load_vocos_model(device, levels=levels)
return _load_encodec_model(device, levels=levels)
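_load_model now dispatches on the audio_backend string instead of a boolean, and @cache keeps one model per (device, backend, levels) tuple. A short usage sketch:

```python
# repeated calls with the same arguments return the same cached instance
model = _load_model("cuda", backend=cfg.inference.audio_backend)
print(model.backend, model.sample_rate)
```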
def unload_model():
_load_model.cache_clear()
_load_encodec_model.cache_clear()
_load_encodec_model.cache_clear() # because vocos can only decode
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
"""
Args:
codes: (b q t)
"""
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
# upcast so it won't whine
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
codes = codes.to(torch.int32)
# expand if we're given a raw 1-RVQ stream
if codes.dim() == 1:
@@ -96,21 +196,49 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
codes = rearrange(codes, "t q -> 1 q t")
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
# load the model
model = _load_model(device, levels=levels)
# upcast so it won't whine
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
codes = codes.to(torch.int32)
# DAC uses a different pathway
if model.backend == "dac":
if metadata is None:
metadata = dict(
chunk_length=416,
original_length=0,
input_db=-12,
channels=1,
sample_rate=model.sample_rate,
padding=False,
dac_version='1.0.0',
)
# generate object with copied metadata
artifact = DACFile(
codes = codes,
# yes this could just **kwargs from a dict, but I also want to accept an actual DACFile.metadata from elsewhere
chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
)
return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
kwargs = {}
if model.backend == "vocos":
x = model.codes_to_features(codes[0])
kwargs['bandwidth_id'] = model.bandwidth_id
else:
# encodec will decode as a batch
x = [(codes.to(device), None)]
wav = model.decode(x, **kwargs)
# encodec will decode as a batch
if model.backend == "encodec":
wav = wav[0]
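Putting the two DAC paths together: encode can hand back the full DACFile, which decode then consumes as metadata instead of the hardcoded defaults above. A hedged round-trip sketch, assuming a waveform tensor wav at cfg.sample_rate and cfg.inference.audio_backend set to "dac":

```python
# round-trip through DAC; artifact carries the real chunk_length/original_length
artifact = encode(wav, cfg.sample_rate, device="cuda", return_metadata=True)
audio, sample_rate = decode(artifact.codes, device="cuda", metadata=artifact)
```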
@@ -131,13 +259,14 @@ def _replace_file_extension(path, suffix):
@torch.inference_mode()
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels):
"""
Args:
wav: (t)
sr: int
"""
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=False):
if cfg.inference.audio_backend == "dac":
model = _load_dac_model(device, levels=levels)
signal = AudioSignal(wav, sample_rate=model.sample_rate)
artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None)
return artifact.codes if not return_metadata else artifact
# vocos does not encode, so fall back to the normal encodec encoder
model = _load_encodec_model(device, levels=levels)
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
@@ -180,8 +309,9 @@ def encode_from_file(path, device="cuda"):
return qnt
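A quick usage sketch of the file-level helper (path hypothetical); which codec does the encoding now follows cfg.inference.audio_backend:

```python
# hypothetical path: encode one audio file with the configured backend
qnt = encode_from_file("data/sample.wav", device="cuda")
```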
# Helper Functions
"""
Helper Functions
"""
# trims from the start, up to `target`
def trim( qnt, target ):
length = max( qnt.shape[0], qnt.shape[1] )
@@ -208,7 +338,7 @@ def trim_random( qnt, target ):
end = start + target
if end >= length:
start = length - target
end = length
end = length
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
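Example usage of the trimming helpers on a (t, q) code tensor; the frames-per-second depends on the codec (EnCodec 24kHz runs at 75 frames/sec, so the target below is illustrative):

```python
# randomly crop a quantized clip to roughly 3 seconds of EnCodec frames
cropped = trim_random(qnt, target=3 * 75)
```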
@@ -233,13 +363,14 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
decoded[i] = decoded[i] * scale[i]
combined = sum(decoded) / len(decoded)
return encode(combined, 24_000, device="cpu", levels=levels)[0].t()
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
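merge_audio decodes each argument, applies the optional per-clip scale, averages the waveforms, and re-encodes at cfg.sample_rate (the qnt_a/qnt_b names below are hypothetical):

```python
# blend two prompts at half volume each before re-encoding on the CPU
merged = merge_audio(qnt_a, qnt_b, device="cpu", scale=[0.5, 0.5])
```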
def main():
parser = argparse.ArgumentParser()
parser.add_argument("folder", type=Path)
parser.add_argument("--suffix", default=".wav")
parser.add_argument("--device", default="cuda")
parser.add_argument("--backend", default="encodec")
args = parser.parse_args()
device = args.device
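The new --backend flag selects the codec when batch-encoding a folder from the CLI; an illustrative invocation (module path assumed):

```python
# hypothetical invocation, module path assumed:
#   python3 -m vall_e.emb.qnt ./voices --suffix .wav --device cuda --backend dac
```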

View File

@@ -336,7 +336,9 @@ def example_usage():
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)])
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
qnt = torch.load(f'data/qnt{".dac" if cfg.inference.audio_backend == "dac" else ""}.pt')[0].t()[:, :cfg.model.prom_levels].to(device)
print(qnt.shape)
cfg.hyperparameters.gradient_accumulation_steps = 1
@@ -426,11 +428,15 @@ def example_usage():
@torch.inference_mode()
def sample( name, steps=600 ):
if cfg.inference.audio_backend == "dac" and name == "init":
return
engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
if cfg.inference.audio_backend != "dac":
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )