DAC just doesn't work well enough......

mrq 2024-05-25 11:07:52 -05:00
parent e3ef89f5aa
commit ddbacde0d1
8 changed files with 132 additions and 37 deletions

View File

@ -143,7 +143,8 @@ For audio backends:
* [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
- encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
- **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge. Audio encoded through the 44KHz model seems to work.
- **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge in any manner.
- **Note** models using `descript-audio-codec` at 44KHz + 8kbps stop improving after a while.
`llama`-based models also support different attention backends:
* `math`: torch's SDPA's `math` implementation
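Since `cfg.inference.audio_backend` is being demoted to a legacy fallback in favor of a top-level `cfg.audio_backend` (see the config changes below), here is a minimal sketch of selecting a backend programmatically; the field names and string values mirror the ones used elsewhere in this commit:

```python
from vall_e.config import cfg

# sketch only: pick one of the backends listed above
cfg.audio_backend = "vocos"   # "encodec", "vocos", or "dac"
cfg.sample_rate = 24_000      # 44_000 when quantizing through the 44KHz DAC model
```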

View File

@ -8,8 +8,8 @@ from pathlib import Path
from vall_e.config import cfg
# things that could be args
cfg.sample_rate = 44_000
cfg.inference.audio_backend = "dac"
cfg.sample_rate = 24_000
cfg.inference.audio_backend = "encodec"
"""
cfg.inference.weight_dtype = "bfloat16"
cfg.inference.dtype = torch.bfloat16

View File

@ -1,14 +1,29 @@
import os
import json
import torch
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
from vall_e.config import cfg
# things that could be args
cfg.sample_rate = 24_000
cfg.inference.audio_backend = "encodec"
"""
cfg.inference.weight_dtype = "bfloat16"
cfg.inference.dtype = torch.bfloat16
cfg.inference.amp = True
"""
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
input_dataset = "LibriTTS_R"
output_dataset = "LibriTTS-Train"
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
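# resolves to "LibriTTS-Train-24KHz" for 24KHz sources, or "LibriTTS-Train-44KHz" otherwise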
device = "cuda"
txts = []
@ -32,24 +47,61 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}')
outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}')
if ".original.txt" in filename and not _replace_file_extension(outpath, ".json").exists():
txts.append([inpath, outpath])
if ".wav" in filename and not _replace_file_extension(outpath, ".dac").exists():
wavs.append([inpath, outpath])
if ".wav" in filename: # and not _replace_file_extension(outpath, ".dac").exists():
txts.append((
inpath,
outpath
))
for paths in tqdm(txts, desc="Phonemizing..."):
text = open(paths[0], "r", encoding="utf-8").read()
phones = valle_phonemize(text)
data = {
"text": text,
"phonemes": phones,
"language": "english",
}
open(_replace_file_extension(paths[1], ".json"), 'w', encoding='utf-8').write(json.dumps(data))
#phones = valle_phonemize(open(paths[0], "r", encoding="utf-8").read())
#open(_replace_file_extension(paths[1], ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
for paths in tqdm(txts, desc="Processing..."):
inpath, outpath = paths
try:
if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
if not isinstance(data["phonemes"], str):
data["phonemes"] = "".join(data["phonemes"])
for paths in tqdm(wavs, desc="Quantizing..."):
qnt = valle_quantize(paths[0], device=device)
qnt.save(_replace_file_extension(paths[1], ".dac"))
#torch.save(qnt.cpu(), _replace_file_extension(paths[1], ".qnt.pt"))
for k, v in data.items():
qnt[()]['metadata'][k] = v
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
else:
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
phones = valle_phonemize(text)
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
"text": text.strip(),
"phonemes": "".join(phones),
"language": "en",
},
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.shape[0] / 75.0,
"sample_rate": cfg.sample_rate,
"text": text.strip(),
"phonemes": "".join(phones),
"language": "en",
},
})
except Exception as e:
tqdm.write(f"Failed to process: {paths}: {e}")

View File

@ -156,7 +156,7 @@ class Dataset:
p_resp_append: float = 1.0
sample_type: str = "path" # path | speaker
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
_frames_per_second: int = 0 # allows setting your own hint
@ -166,7 +166,7 @@ class Dataset:
if self._frames_per_second > 0:
return self._frames_per_second
if cfg.inference.audio_backend == "dac":
if cfg.audio_backend == "dac":
# using the 44KHz model with 24KHz sources has a frame rate of 41Hz
if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
return 41
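Since the effective frame rate now depends on the backend and sample rate, the new `_frames_per_second` field can be used to override the hint directly; a minimal sketch, assuming the top-level config object:

```python
# sketch: force the dataset bookkeeping to 41 code frames per second,
# e.g. when feeding 24KHz audio through the 44KHz DAC model
cfg.dataset._frames_per_second = 41
```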
@ -567,7 +567,7 @@ class Inference:
amp: bool = False
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
audio_backend: str = "vocos" # encodec, vocos, dac
audio_backend: str = "" # encodec, vocos, dac
# legacy / backwards compat
use_vocos: bool = True
@ -628,6 +628,8 @@ class Config(_Config):
sample_rate: int = 24_000
variable_sample_rate: bool = False # NOT recommended, as running 24KHz audio directly through the 44KHz DAC model incurs a detrimental quality loss
audio_backend: str = "vocos"
@property
def distributed(self):
return world_size() > 1
@ -726,6 +728,9 @@ class Config(_Config):
if self.trainer.backend == "local" and self.distributed:
self.trainer.ddp = True
if self.inference.audio_backend != "" and self.audio_backend == "":
self.audio_backend = self.inference.audio_backend
# Preserves the old behavior
class NaiveTokenizer:

View File

@ -63,10 +63,10 @@ def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_quant_extension():
return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
return ".dac" if cfg.audio_backend == "dac" else ".enc"
def _get_phone_extension():
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"
def _get_quant_path(path):
return _replace_file_extension(path, _get_quant_extension())
@ -876,12 +876,36 @@ def create_dataset_hdf5( skip_existing=True ):
if not os.path.isdir(f'{root}/{name}/'):
return
# tqdm.write(f'{root}/{name}')
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
"""
# rephonemizes if you fuck up and use an old tokenizer...
for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
key = f'{type}/{speaker_name}/{id}'
if key not in hf:
continue
group = hf[key]
if "phonemes" not in entry:
continue
if "text" not in group:
continue
txt = entry["phonemes"]
phn = "".join(txt)
phn = cfg.tokenizer.encode(phn)
phn = np.array(phn).astype(np.uint8)
del group["text"]
group.create_dataset('text', data=phn, compression='lzf')
"""
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
@ -938,8 +962,10 @@ def create_dataset_hdf5( skip_existing=True ):
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
"""
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
"""
# training

View File

@ -170,7 +170,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
return model
@cache
def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
if backend == "dac":
return _load_dac_model(device, levels=levels)
if backend == "vocos":
@ -267,7 +267,7 @@ def _replace_file_extension(path, suffix):
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
if cfg.inference.audio_backend == "dac":
if cfg.audio_backend == "dac":
model = _load_dac_model(device, levels=levels )
signal = AudioSignal(wav, sample_rate=sr)
@ -307,7 +307,7 @@ def encode_from_files(paths, device="cuda"):
wav = torch.cat(wavs, dim=-1)
return encode(wav, sr, "cpu")
return encode(wav, sr, device)
def encode_from_file(path, device="cuda"):
if isinstance( path, list ):

View File

@ -112,7 +112,18 @@ class TTS():
paths = [ Path(p) for p in paths.split(";") ]
# merge inputs
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
proms = []
for path in paths:
prom = qnt.encode_from_file(path)
if hasattr( prom, "codes" ):
prom = prom.codes
prom = prom[0][:, :].t().to(torch.int16)
proms.append( prom )
res = torch.cat(proms)
if trim_length:
res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )
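For example, with the EnCodec backend at 24KHz (75 code frames per second, per the `/ 75.0` in the processing script above), `trim_length=3` keeps `int(75 * 3) = 225` frames of the prompt.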

View File

@ -319,7 +319,7 @@ class AR_NAR(Base):
def example_usage():
#cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.inference.audio_backend == "dac":
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_000
from functools import partial
@ -340,7 +340,7 @@ def example_usage():
return torch.tensor( cfg.tokenizer.encode(content) )
def _load_quants(path) -> Tensor:
if cfg.inference.audio_backend == "dac":
if cfg.audio_backend == "dac":
qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
@ -456,13 +456,13 @@ def example_usage():
@torch.inference_mode()
def sample( name, steps=1000 ):
if cfg.inference.audio_backend == "dac" and name == "init":
if cfg.audio_backend == "dac" and name == "init":
return
engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
if cfg.inference.audio_backend != "dac":
if cfg.audio_backend != "dac":
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)