finally swallowing the Descript-Audio-Codec pill (I guess I'm going to have to regenerate my entire dataset)
This commit is contained in:
parent
b0bd88833c
commit
5ff2b4aab5
@ -484,7 +484,12 @@ class Inference:
|
|||||||
amp: bool = False
|
amp: bool = False
|
||||||
|
|
||||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||||
|
audio_backend: str = "vocos"
|
||||||
|
|
||||||
|
# legacy / backwards compat
|
||||||
use_vocos: bool = True
|
use_vocos: bool = True
|
||||||
|
use_encodec: bool = True
|
||||||
|
use_dac: bool = True
|
||||||
|
|
||||||
recurrent_chunk_size: int = 0
|
recurrent_chunk_size: int = 0
|
||||||
recurrent_forward: bool = False
|
recurrent_forward: bool = False
|
||||||
@ -576,23 +581,31 @@ class Config(_Config):
|
|||||||
self.dataset.use_hdf5 = False
|
self.dataset.use_hdf5 = False
|
||||||
|
|
||||||
def format( self ):
|
def format( self ):
|
||||||
|
#if not isinstance(self.dataset, type):
|
||||||
self.dataset = Dataset(**self.dataset)
|
self.dataset = Dataset(**self.dataset)
|
||||||
|
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
||||||
|
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||||
|
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
||||||
|
|
||||||
|
#if not isinstance(self.model, type):
|
||||||
if self.models is not None:
|
if self.models is not None:
|
||||||
self.model = Model(**next(iter(self.models)))
|
self.model = Model(**next(iter(self.models)))
|
||||||
else:
|
else:
|
||||||
self.model = Model(**self.model)
|
self.model = Model(**self.model)
|
||||||
|
|
||||||
|
#if not isinstance(self.hyperparameters, type):
|
||||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||||
|
#if not isinstance(self.evaluation, type):
|
||||||
self.evaluation = Evaluation(**self.evaluation)
|
self.evaluation = Evaluation(**self.evaluation)
|
||||||
|
#if not isinstance(self.trainer, type):
|
||||||
self.trainer = Trainer(**self.trainer)
|
self.trainer = Trainer(**self.trainer)
|
||||||
|
if not isinstance(self.trainer.deepspeed, type):
|
||||||
|
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||||
|
#if not isinstance(self.inference, type):
|
||||||
self.inference = Inference(**self.inference)
|
self.inference = Inference(**self.inference)
|
||||||
|
#if not isinstance(self.bitsandbytes, type):
|
||||||
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
|
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
|
||||||
|
|
||||||
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
|
||||||
|
|
||||||
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
|
|
||||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
|
||||||
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
|
||||||
|
|
||||||
|
|
||||||
cfg = Config.from_cli()
|
cfg = Config.from_cli()
|
||||||
|
|
||||||
|
@ -9,20 +9,89 @@ from functools import cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
from encodec import EncodecModel
|
|
||||||
from encodec.utils import convert_audio
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
from encodec import EncodecModel
|
||||||
|
from encodec.utils import convert_audio
|
||||||
|
except Exception as e:
|
||||||
|
cfg.inference.use_encodec = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vocos import Vocos
|
from vocos import Vocos
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cfg.inference.use_vocos = False
|
cfg.inference.use_vocos = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dac import DACFile
|
||||||
|
from audiotools import AudioSignal
|
||||||
|
from dac.utils import load_model as __load_dac_model
|
||||||
|
|
||||||
|
"""
|
||||||
|
Patch decode to skip things related to the metadata (namely the waveform trimming)
|
||||||
|
So far it seems the raw waveform can just be returned without any post-processing
|
||||||
|
A smart implementation would just reuse the values from the input prompt
|
||||||
|
"""
|
||||||
|
from dac.model.base import CodecMixin
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def CodecMixin_decompress(
|
||||||
|
self,
|
||||||
|
obj: Union[str, Path, DACFile],
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> AudioSignal:
|
||||||
|
self.eval()
|
||||||
|
if isinstance(obj, (str, Path)):
|
||||||
|
obj = DACFile.load(obj)
|
||||||
|
|
||||||
|
original_padding = self.padding
|
||||||
|
self.padding = obj.padding
|
||||||
|
|
||||||
|
range_fn = range if not verbose else tqdm.trange
|
||||||
|
codes = obj.codes
|
||||||
|
original_device = codes.device
|
||||||
|
chunk_length = obj.chunk_length
|
||||||
|
recons = []
|
||||||
|
|
||||||
|
for i in range_fn(0, codes.shape[-1], chunk_length):
|
||||||
|
c = codes[..., i : i + chunk_length].to(self.device)
|
||||||
|
z = self.quantizer.from_codes(c)[0]
|
||||||
|
r = self.decode(z)
|
||||||
|
recons.append(r.to(original_device))
|
||||||
|
|
||||||
|
recons = torch.cat(recons, dim=-1)
|
||||||
|
recons = AudioSignal(recons, self.sample_rate)
|
||||||
|
|
||||||
|
# to-do, original implementation
|
||||||
|
"""
|
||||||
|
resample_fn = recons.resample
|
||||||
|
loudness_fn = recons.loudness
|
||||||
|
|
||||||
|
# If audio is > 10 minutes long, use the ffmpeg versions
|
||||||
|
if recons.signal_duration >= 10 * 60 * 60:
|
||||||
|
resample_fn = recons.ffmpeg_resample
|
||||||
|
loudness_fn = recons.ffmpeg_loudness
|
||||||
|
|
||||||
|
recons.normalize(obj.input_db)
|
||||||
|
resample_fn(obj.sample_rate)
|
||||||
|
recons = recons[..., : obj.original_length]
|
||||||
|
loudness_fn()
|
||||||
|
recons.audio_data = recons.audio_data.reshape(
|
||||||
|
-1, obj.channels, obj.original_length
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
self.padding = original_padding
|
||||||
|
return recons
|
||||||
|
|
||||||
|
CodecMixin.decompress = CodecMixin_decompress
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
cfg.inference.use_dac = False
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
||||||
# Instantiate a pretrained EnCodec model
|
|
||||||
assert cfg.sample_rate == 24_000
|
assert cfg.sample_rate == 24_000
|
||||||
|
|
||||||
# too lazy to un-if ladder this shit
|
# too lazy to un-if ladder this shit
|
||||||
@ -34,8 +103,14 @@ def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
|
|||||||
elif levels == 8:
|
elif levels == 8:
|
||||||
bandwidth_id = 6.0
|
bandwidth_id = 6.0
|
||||||
|
|
||||||
model = EncodecModel.encodec_model_24khz().to(device)
|
# Instantiate a pretrained EnCodec model
|
||||||
|
model = EncodecModel.encodec_model_24khz()
|
||||||
model.set_target_bandwidth(bandwidth_id)
|
model.set_target_bandwidth(bandwidth_id)
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
# extra metadata
|
||||||
model.bandwidth_id = bandwidth_id
|
model.bandwidth_id = bandwidth_id
|
||||||
model.sample_rate = cfg.sample_rate
|
model.sample_rate = cfg.sample_rate
|
||||||
model.normalize = cfg.inference.normalize
|
model.normalize = cfg.inference.normalize
|
||||||
@ -49,6 +124,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
|
|||||||
|
|
||||||
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
|
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
# too lazy to un-if ladder this shit
|
# too lazy to un-if ladder this shit
|
||||||
bandwidth_id = 2
|
bandwidth_id = 2
|
||||||
@ -59,6 +135,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
|
|||||||
elif levels == 8:
|
elif levels == 8:
|
||||||
bandwidth_id = 2
|
bandwidth_id = 2
|
||||||
|
|
||||||
|
# extra metadata
|
||||||
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
|
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
|
||||||
model.sample_rate = cfg.sample_rate
|
model.sample_rate = cfg.sample_rate
|
||||||
model.backend = "vocos"
|
model.backend = "vocos"
|
||||||
@ -66,25 +143,48 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels):
|
def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
||||||
if vocos:
|
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
||||||
model = _load_vocos_model(device, levels=levels)
|
|
||||||
|
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
||||||
|
if cfg.sample_rate == 44_000:
|
||||||
|
kwargs["model_type"] = "44kz"
|
||||||
|
elif cfg.sample_rate == 24_000:
|
||||||
|
kwargs["model_type"] = "24khz"
|
||||||
|
elif cfg.sample_rate == 16_000:
|
||||||
|
kwargs["model_type"] = "16khz"
|
||||||
else:
|
else:
|
||||||
model = _load_encodec_model(device, levels=levels)
|
raise Exception(f'unsupported sample rate: {cfg.sample_rate}')
|
||||||
|
|
||||||
|
model = __load_dac_model(**kwargs)
|
||||||
|
model = model.to(device)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
# extra metadata
|
||||||
|
model.sample_rate = cfg.sample_rate
|
||||||
|
model.backend = "dac"
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
|
||||||
|
if backend == "dac":
|
||||||
|
return _load_dac_model(device, levels=levels)
|
||||||
|
if backend == "vocos":
|
||||||
|
return _load_vocos_model(device, levels=levels)
|
||||||
|
|
||||||
|
return _load_encodec_model(device, levels=levels)
|
||||||
|
|
||||||
def unload_model():
|
def unload_model():
|
||||||
_load_model.cache_clear()
|
_load_model.cache_clear()
|
||||||
_load_encodec_model.cache_clear()
|
_load_encodec_model.cache_clear() # because vocos can only decode
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
|
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
|
||||||
"""
|
# upcast so it won't whine
|
||||||
Args:
|
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
|
||||||
codes: (b q t)
|
codes = codes.to(torch.int32)
|
||||||
"""
|
|
||||||
|
|
||||||
# expand if we're given a raw 1-RVQ stream
|
# expand if we're given a raw 1-RVQ stream
|
||||||
if codes.dim() == 1:
|
if codes.dim() == 1:
|
||||||
@ -96,21 +196,49 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
|
|||||||
codes = rearrange(codes, "t q -> 1 q t")
|
codes = rearrange(codes, "t q -> 1 q t")
|
||||||
|
|
||||||
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
||||||
|
|
||||||
|
# load the model
|
||||||
model = _load_model(device, levels=levels)
|
model = _load_model(device, levels=levels)
|
||||||
|
|
||||||
# upcast so it won't whine
|
# DAC uses a different pathway
|
||||||
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
|
if model.backend == "dac":
|
||||||
codes = codes.to(torch.int32)
|
if metadata is None:
|
||||||
|
metadata = dict(
|
||||||
|
chunk_length=416,
|
||||||
|
original_length=0,
|
||||||
|
input_db=-12,
|
||||||
|
channels=1,
|
||||||
|
sample_rate=model.sample_rate,
|
||||||
|
padding=False,
|
||||||
|
dac_version='1.0.0',
|
||||||
|
)
|
||||||
|
# generate object with copied metadata
|
||||||
|
artifact = DACFile(
|
||||||
|
codes = codes,
|
||||||
|
# yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
|
||||||
|
chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
|
||||||
|
original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
|
||||||
|
input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
|
||||||
|
channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
|
||||||
|
sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
|
||||||
|
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
|
||||||
|
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
|
||||||
|
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if model.backend == "vocos":
|
if model.backend == "vocos":
|
||||||
x = model.codes_to_features(codes[0])
|
x = model.codes_to_features(codes[0])
|
||||||
kwargs['bandwidth_id'] = model.bandwidth_id
|
kwargs['bandwidth_id'] = model.bandwidth_id
|
||||||
else:
|
else:
|
||||||
|
# encodec will decode as a batch
|
||||||
x = [(codes.to(device), None)]
|
x = [(codes.to(device), None)]
|
||||||
|
|
||||||
wav = model.decode(x, **kwargs)
|
wav = model.decode(x, **kwargs)
|
||||||
|
|
||||||
|
# encodec will decode as a batch
|
||||||
if model.backend == "encodec":
|
if model.backend == "encodec":
|
||||||
wav = wav[0]
|
wav = wav[0]
|
||||||
|
|
||||||
@ -131,13 +259,14 @@ def _replace_file_extension(path, suffix):
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels):
|
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=False):
|
||||||
"""
|
if cfg.inference.audio_backend == "dac":
|
||||||
Args:
|
model = _load_dac_model(device, levels=levels)
|
||||||
wav: (t)
|
signal = AudioSignal(wav, sample_rate=model.sample_rate)
|
||||||
sr: int
|
artifact = model.compress(signal, 5.0, verbose=False, n_quantizers=levels if isinstance(levels, int) else None)
|
||||||
"""
|
return artifact.codes if not return_metadata else artifact
|
||||||
|
|
||||||
|
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||||
model = _load_encodec_model(device, levels=levels)
|
model = _load_encodec_model(device, levels=levels)
|
||||||
wav = wav.unsqueeze(0)
|
wav = wav.unsqueeze(0)
|
||||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||||
@ -180,8 +309,9 @@ def encode_from_file(path, device="cuda"):
|
|||||||
|
|
||||||
return qnt
|
return qnt
|
||||||
|
|
||||||
# Helper Functions
|
"""
|
||||||
|
Helper Functions
|
||||||
|
"""
|
||||||
# trims from the start, up to `target`
|
# trims from the start, up to `target`
|
||||||
def trim( qnt, target ):
|
def trim( qnt, target ):
|
||||||
length = max( qnt.shape[0], qnt.shape[1] )
|
length = max( qnt.shape[0], qnt.shape[1] )
|
||||||
@ -233,13 +363,14 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
|||||||
decoded[i] = decoded[i] * scale[i]
|
decoded[i] = decoded[i] * scale[i]
|
||||||
|
|
||||||
combined = sum(decoded) / len(decoded)
|
combined = sum(decoded) / len(decoded)
|
||||||
return encode(combined, 24_000, device="cpu", levels=levels)[0].t()
|
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("folder", type=Path)
|
parser.add_argument("folder", type=Path)
|
||||||
parser.add_argument("--suffix", default=".wav")
|
parser.add_argument("--suffix", default=".wav")
|
||||||
parser.add_argument("--device", default="cuda")
|
parser.add_argument("--device", default="cuda")
|
||||||
|
parser.add_argument("--backend", default="encodec")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
device = args.device
|
device = args.device
|
||||||
|
@ -336,7 +336,9 @@ def example_usage():
|
|||||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||||
return torch.tensor([*map(symmap.get, phones)])
|
return torch.tensor([*map(symmap.get, phones)])
|
||||||
|
|
||||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
|
qnt = torch.load(f'data/qnt{".dac" if cfg.inference.audio_backend == "dac" else ""}.pt')[0].t()[:, :cfg.model.prom_levels].to(device)
|
||||||
|
|
||||||
|
print(qnt.shape)
|
||||||
|
|
||||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||||
|
|
||||||
@ -426,11 +428,15 @@ def example_usage():
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def sample( name, steps=600 ):
|
def sample( name, steps=600 ):
|
||||||
|
if cfg.inference.audio_backend == "dac" and name == "init":
|
||||||
|
return
|
||||||
|
|
||||||
engine.eval()
|
engine.eval()
|
||||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||||
|
|
||||||
for i, o in enumerate(resps_list):
|
if cfg.inference.audio_backend != "dac":
|
||||||
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
for i, o in enumerate(resps_list):
|
||||||
|
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
||||||
|
|
||||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||||
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
||||||
|
Loading…
Reference in New Issue
Block a user