finally swallowing the Descript-Audio-Codec pill (I guess I'm going to have to regenerate my entire dataset)
This commit is contained in:
parent
b0bd88833c
commit
5ff2b4aab5
|
@ -484,7 +484,12 @@ class Inference:
|
|||
amp: bool = False
|
||||
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
audio_backend: str = "vocos"
|
||||
|
||||
# legacy / backwards compat
|
||||
use_vocos: bool = True
|
||||
use_encodec: bool = True
|
||||
use_dac: bool = True
|
||||
|
||||
recurrent_chunk_size: int = 0
|
||||
recurrent_forward: bool = False
|
||||
|
@ -576,22 +581,30 @@ class Config(_Config):
|
|||
self.dataset.use_hdf5 = False
|
||||
|
||||
def format( self ):
|
||||
#if not isinstance(self.dataset, type):
|
||||
self.dataset = Dataset(**self.dataset)
|
||||
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
||||
|
||||
#if not isinstance(self.model, type):
|
||||
if self.models is not None:
|
||||
self.model = Model(**next(iter(self.models)))
|
||||
else:
|
||||
self.model = Model(**self.model)
|
||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||
self.evaluation = Evaluation(**self.evaluation)
|
||||
self.trainer = Trainer(**self.trainer)
|
||||
self.inference = Inference(**self.inference)
|
||||
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
|
||||
|
||||
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||
|
||||
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
||||
#if not isinstance(self.hyperparameters, type):
|
||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||
#if not isinstance(self.evaluation, type):
|
||||
self.evaluation = Evaluation(**self.evaluation)
|
||||
#if not isinstance(self.trainer, type):
|
||||
self.trainer = Trainer(**self.trainer)
|
||||
if not isinstance(self.trainer.deepspeed, type):
|
||||
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||
#if not isinstance(self.inference, type):
|
||||
self.inference = Inference(**self.inference)
|
||||
#if not isinstance(self.bitsandbytes, type):
|
||||
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
|
||||
|
||||
|
||||
cfg = Config.from_cli()
|
||||
|
|
|
@ -9,20 +9,89 @@ from functools import cache
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
from encodec import EncodecModel
|
||||
from encodec.utils import convert_audio
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from encodec import EncodecModel
|
||||
from encodec.utils import convert_audio
|
||||
except Exception as e:
|
||||
cfg.inference.use_encodec = False
|
||||
|
||||
try:
|
||||
from vocos import Vocos
|
||||
except Exception as e:
|
||||
cfg.inference.use_vocos = False
|
||||
|
||||
try:
    # `Union` is needed by the annotation below; importing it here keeps a
    # missing top-of-file import from silently disabling DAC via the except.
    from typing import Union

    from dac import DACFile
    from audiotools import AudioSignal
    from dac.utils import load_model as __load_dac_model
    from dac.model.base import CodecMixin

    # Patch decompress to skip things related to the metadata (namely the
    # waveform trimming). So far it seems the raw waveform can just be returned
    # without any post-processing. A smart implementation would just reuse the
    # values from the input prompt.

    @torch.no_grad()
    def CodecMixin_decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Decode a DACFile (or a path to one) back into an AudioSignal.

        The codes are decoded chunk by chunk (``obj.chunk_length`` frames at a
        time) on ``self.device`` and each chunk is moved back to the device the
        codes came from before being concatenated along the last axis.

        Unlike the upstream implementation, no loudness normalisation,
        resampling or trimming to ``obj.original_length`` is applied.

        Args:
            obj: a DACFile, or a str/Path that is loaded with ``DACFile.load``.
            verbose: when True, iterate with a tqdm progress bar.

        Returns:
            An AudioSignal built from the decoded waveform at ``self.sample_rate``.
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        if verbose:
            # `tqdm` in this module is the class (from tqdm import tqdm), which
            # has no `trange` attribute; import the function explicitly.
            from tqdm import trange as range_fn
        else:
            range_fn = range

        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        # temporarily adopt the padding mode the file was encoded with
        original_padding = self.padding
        self.padding = obj.padding
        try:
            for i in range_fn(0, codes.shape[-1], chunk_length):
                c = codes[..., i : i + chunk_length].to(self.device)
                z = self.quantizer.from_codes(c)[0]
                r = self.decode(z)
                recons.append(r.to(original_device))
        finally:
            # always restore the model's own padding, even if decoding fails
            self.padding = original_padding

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        # to-do, original implementation (kept for reference, intentionally disabled):
        #
        #   resample_fn = recons.resample
        #   loudness_fn = recons.loudness
        #
        #   # If audio is > 10 minutes long, use the ffmpeg versions
        #   if recons.signal_duration >= 10 * 60 * 60:
        #       resample_fn = recons.ffmpeg_resample
        #       loudness_fn = recons.ffmpeg_loudness
        #
        #   recons.normalize(obj.input_db)
        #   resample_fn(obj.sample_rate)
        #   recons = recons[..., : obj.original_length]
        #   loudness_fn()
        #   recons.audio_data = recons.audio_data.reshape(
        #       -1, obj.channels, obj.original_length
        #   )

        return recons

    CodecMixin.decompress = CodecMixin_decompress

except Exception:
    # best effort: any failure importing or patching DAC just disables the backend
    cfg.inference.use_dac = False
|
||||
|
||||
@cache
|
||||
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
||||
# Instantiate a pretrained EnCodec model
|
||||
assert cfg.sample_rate == 24_000
|
||||
|
||||
# too lazy to un-if ladder this shit
|
||||
|
@ -34,8 +103,14 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
|||
elif levels == 8:
|
||||
bandwidth_id = 6.0
|
||||
|
||||
model = EncodecModel.encodec_model_24khz().to(device)
|
||||
# Instantiate a pretrained EnCodec model
|
||||
model = EncodecModel.encodec_model_24khz()
|
||||
model.set_target_bandwidth(bandwidth_id)
|
||||
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
# extra metadata
|
||||
model.bandwidth_id = bandwidth_id
|
||||
model.sample_rate = cfg.sample_rate
|
||||
model.normalize = cfg.inference.normalize
|
||||
|
@ -49,6 +124,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
|
|||
|
||||
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
# too lazy to un-if ladder this shit
|
||||
bandwidth_id = 2
|
||||
|
@ -59,6 +135,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
|
|||
elif levels == 8:
|
||||
bandwidth_id = 2
|
||||
|
||||
# extra metadata
|
||||
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
|
||||
model.sample_rate = cfg.sample_rate
|
||||
model.backend = "vocos"
|
||||
|
@ -66,25 +143,48 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
|
|||
return model
|
||||
|
||||
@cache
|
||||
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels):
|
||||
if vocos:
|
||||
model = _load_vocos_model(device, levels=levels)
|
||||
def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
    """Load the pretrained Descript Audio Codec model for ``cfg.sample_rate``.

    Args:
        device: device the model is moved to.
        levels: unused here; kept so the signature matches the other loaders.

    Returns:
        The DAC model in eval mode, tagged with ``sample_rate`` (taken from
        ``cfg.sample_rate``) and ``backend = "dac"``.

    Raises:
        ValueError: if ``cfg.sample_rate`` has no matching DAC checkpoint.
    """
    # sample rate -> DAC checkpoint name
    # (could also be derived, e.g. f'{cfg.sample_rate // 1000}khz')
    model_types = {
        # presumably the "44khz" checkpoint is really 44.1 kHz -- TODO confirm;
        # 44_000 stays accepted because existing configs were keyed on it
        44_100: "44khz",
        44_000: "44khz",
        24_000: "24khz",
        16_000: "16khz",
    }

    model_type = model_types.get(cfg.sample_rate)
    if model_type is None:
        raise ValueError(f'unsupported sample rate: {cfg.sample_rate}')

    kwargs = dict(model_type=model_type, model_bitrate="8kbps", tag="latest")

    model = __load_dac_model(**kwargs)
    model = model.to(device)
    model = model.eval()

    # extra metadata
    model.sample_rate = cfg.sample_rate
    model.backend = "dac"

    return model
|
||||
|
||||
@cache
def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
    """Return the codec model for `backend`, loaded onto `device`.

    "dac" and "vocos" select their dedicated loaders; any other value falls
    back to the EnCodec loader. Results are memoised per argument combination.
    """
    dedicated_loaders = {
        "dac": _load_dac_model,
        "vocos": _load_vocos_model,
    }
    # anything that is not explicitly listed is treated as plain EnCodec
    loader = dedicated_loaders.get(backend, _load_encodec_model)
    return loader(device, levels=levels)
|
||||
|
||||
def unload_model():
    """Drop every memoised codec model so it can be garbage collected.

    `_load_model` only forwards to the backend-specific loaders, so clearing
    its cache alone would leave the models referenced by the loaders' own
    caches. EnCodec is cleared as well because encoding always goes through
    it (vocos can only decode).
    """
    loaders = (
        _load_model,
        _load_encodec_model,
        _load_vocos_model,
        _load_dac_model,
    )
    for loader in loaders:
        # only functools-cached loaders expose cache_clear; skip any that are not
        clear = getattr(loader, "cache_clear", None)
        if clear is not None:
            clear()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
|
||||
"""
|
||||
Args:
|
||||
codes: (b q t)
|
||||
"""
|
||||
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
|
||||
# upcast so it won't whine
|
||||
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
|
||||
codes = codes.to(torch.int32)
|
||||
|
||||
# expand if we're given a raw 1-RVQ stream
|
||||
if codes.dim() == 1:
|
||||
|
@ -96,21 +196,49 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
|
|||
codes = rearrange(codes, "t q -> 1 q t")
|
||||
|
||||
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
||||
|
||||
# load the model
|
||||
model = _load_model(device, levels=levels)
|
||||
|
||||
# upcast so it won't whine
|
||||
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
|
||||
codes = codes.to(torch.int32)
|
||||
# DAC uses a different pathway
|
||||
if model.backend == "dac":
|
||||
if metadata is None:
|
||||
metadata = dict(
|
||||
chunk_length=416,
|
||||
original_length=0,
|
||||
input_db=-12,
|
||||
channels=1,
|
||||
sample_rate=model.sample_rate,
|
||||
padding=False,
|
||||
dac_version='1.0.0',
|
||||
)
|
||||
# generate object with copied metadata
|
||||
artifact = DACFile(
|
||||
codes = codes,
|
||||
# yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
|
||||
chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
|
||||
original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
|
||||
input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
|
||||
channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
|
||||
sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
|
||||
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
|
||||
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
|
||||
)
|
||||
|
||||
return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
|
||||
|
||||
|
||||
kwargs = {}
|
||||
if model.backend == "vocos":
|
||||
x = model.codes_to_features(codes[0])
|
||||
kwargs['bandwidth_id'] = model.bandwidth_id
|
||||
else:
|
||||
# encodec will decode as a batch
|
||||
x = [(codes.to(device), None)]
|
||||
|
||||
wav = model.decode(x, **kwargs)
|
||||
|
||||
# encodec will decode as a batch
|
||||
if model.backend == "encodec":
|
||||
wav = wav[0]
|
||||
|
||||
|
@ -131,13 +259,14 @@ def _replace_file_extension(path, suffix):
|
|||
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels):
|
||||
"""
|
||||
Args:
|
||||
wav: (t)
|
||||
sr: int
|
||||
"""
|
||||
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=False):
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
model = _load_dac_model(device, levels=levels)
|
||||
signal = AudioSignal(wav, sample_rate=model.sample_rate)
|
||||
artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None)
|
||||
return artifact.codes if not return_metadata else artifact
|
||||
|
||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||
model = _load_encodec_model(device, levels=levels)
|
||||
wav = wav.unsqueeze(0)
|
||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||
|
@ -180,8 +309,9 @@ def encode_from_file(path, device="cuda"):
|
|||
|
||||
return qnt
|
||||
|
||||
# Helper Functions
|
||||
|
||||
"""
|
||||
Helper Functions
|
||||
"""
|
||||
# trims from the start, up to `target`
|
||||
def trim( qnt, target ):
|
||||
length = max( qnt.shape[0], qnt.shape[1] )
|
||||
|
@ -208,7 +338,7 @@ def trim_random( qnt, target ):
|
|||
end = start + target
|
||||
if end >= length:
|
||||
start = length - target
|
||||
end = length
|
||||
end = length
|
||||
|
||||
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
|
||||
|
||||
|
@ -233,13 +363,14 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
|||
decoded[i] = decoded[i] * scale[i]
|
||||
|
||||
combined = sum(decoded) / len(decoded)
|
||||
return encode(combined, 24_000, device="cpu", levels=levels)[0].t()
|
||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("folder", type=Path)
|
||||
parser.add_argument("--suffix", default=".wav")
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--backend", default="encodec")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = args.device
|
||||
|
|
|
@ -336,7 +336,9 @@ def example_usage():
|
|||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
return torch.tensor([*map(symmap.get, phones)])
|
||||
|
||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
|
||||
qnt = torch.load(f'data/qnt{".dac" if cfg.inference.audio_backend == "dac" else ""}.pt')[0].t()[:, :cfg.model.prom_levels].to(device)
|
||||
|
||||
print(qnt.shape)
|
||||
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
|
||||
|
@ -426,11 +428,15 @@ def example_usage():
|
|||
|
||||
@torch.inference_mode()
|
||||
def sample( name, steps=600 ):
|
||||
if cfg.inference.audio_backend == "dac" and name == "init":
|
||||
return
|
||||
|
||||
engine.eval()
|
||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
|
||||
for i, o in enumerate(resps_list):
|
||||
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
||||
if cfg.inference.audio_backend != "dac":
|
||||
for i, o in enumerate(resps_list):
|
||||
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
||||
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
||||
|
|
Loading…
Reference in New Issue
Block a user