DAC just doesn't work well enough......

This commit is contained in:
mrq 2024-05-25 11:07:52 -05:00
parent e3ef89f5aa
commit ddbacde0d1
8 changed files with 132 additions and 37 deletions

View File

@@ -143,7 +143,8 @@ For audio backends:
 * [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
   - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
 * [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
-  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge. Audio encoded through the 44KHz seems to work.
+  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge in any manner.
+  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps stop improving after a while.
 
 `llama`-based models also support different attention backends:
 * `math`: torch's SDPA's `math` implementation
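Since the backend now resolves through the new top-level `cfg.audio_backend` (see the config changes below), here is a minimal sketch of pinning a backend before encoding, mirroring what the processing scripts in this commit do; the input path is a placeholder:

```python
from vall_e.config import cfg

# set the backend before importing vall_e.emb.qnt: the cached codec
# loader reads cfg.audio_backend as a default argument at import time;
# "dac" pairs with 44KHz audio, the EnCodec-family backends with 24KHz
cfg.sample_rate = 24_000
cfg.audio_backend = "encodec"  # or "vocos" / "dac"

from vall_e.emb.qnt import encode_from_file

qnt = encode_from_file("sample.wav", device="cuda")  # placeholder path
```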

View File

@@ -8,8 +8,8 @@ from pathlib import Path
 from vall_e.config import cfg
 
 # things that could be args
-cfg.sample_rate = 44_000
-cfg.inference.audio_backend = "dac"
+cfg.sample_rate = 24_000
+cfg.inference.audio_backend = "encodec"
 """
 cfg.inference.weight_dtype = "bfloat16"
 cfg.inference.dtype = torch.bfloat16

View File

@@ -1,14 +1,29 @@
 import os
 import json
 import torch
+import numpy as np
 
 from tqdm.auto import tqdm
 from pathlib import Path
 
+from vall_e.config import cfg
+
+# things that could be args
+cfg.sample_rate = 24_000
+cfg.inference.audio_backend = "encodec"
+"""
+cfg.inference.weight_dtype = "bfloat16"
+cfg.inference.dtype = torch.bfloat16
+cfg.inference.amp = True
+"""
+
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
 
+audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+
 input_dataset = "LibriTTS_R"
-output_dataset = "LibriTTS-Train"
+output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
 device = "cuda"
 
 txts = []
@@ -32,24 +47,61 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
 				inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}')
 				outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}')
 
-				if ".original.txt" in filename and not _replace_file_extension(outpath, ".json").exists():
-					txts.append([inpath, outpath])
-				if ".wav" in filename and not _replace_file_extension(outpath, ".dac").exists():
-					wavs.append([inpath, outpath])
+				if ".wav" in filename: # and not _replace_file_extension(outpath, ".dac").exists():
+					txts.append((
+						inpath,
+						outpath
+					))
 
-for paths in tqdm(txts, desc="Phonemizing..."):
-	text = open(paths[0], "r", encoding="utf-8").read()
-	phones = valle_phonemize(text)
-	data = {
-		"text": text,
-		"phonemes": phones,
-		"language": "english",
-	}
-	open(_replace_file_extension(paths[1], ".json"), 'w', encoding='utf-8').write(json.dumps(data))
-	#phones = valle_phonemize(open(paths[0], "r", encoding="utf-8").read())
-	#open(_replace_file_extension(paths[1], ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
-
-for paths in tqdm(wavs, desc="Quantizing..."):
-	qnt = valle_quantize(paths[0], device=device)
-	qnt.save(_replace_file_extension(paths[1], ".dac"))
-	#torch.save(qnt.cpu(), _replace_file_extension(paths[1], ".qnt.pt"))
+for paths in tqdm(txts, desc="Processing..."):
+	inpath, outpath = paths
+	try:
+		if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
+			data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
+			qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
+
+			if not isinstance(data["phonemes"], str):
+				data["phonemes"] = "".join(data["phonemes"])
+
+			for k, v in data.items():
+				qnt[()]['metadata'][k] = v
+
+			np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
+		else:
+			text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+			phones = valle_phonemize(text)
+			qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
+
+			if cfg.inference.audio_backend == "dac":
+				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+					"metadata": {
+						"original_length": qnt.original_length,
+						"sample_rate": qnt.sample_rate,
+						"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+						"chunk_length": qnt.chunk_length,
+						"channels": qnt.channels,
+						"padding": qnt.padding,
+						"dac_version": "1.0.0",
+						"text": text.strip(),
+						"phonemes": "".join(phones),
+						"language": "en",
+					},
+				})
+			else:
+				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+					"codes": qnt.cpu().numpy().astype(np.uint16),
+					"metadata": {
+						"original_length": qnt.shape[0] / 75.0,
+						"sample_rate": cfg.sample_rate,
+						"text": text.strip(),
+						"phonemes": "".join(phones),
+						"language": "en",
+					},
+				})
+	except Exception as e:
+		tqdm.write(f"Failed to process: {paths}: {e}")

View File

@@ -166,7 +166,7 @@ class Dataset:
 		if self._frames_per_second > 0:
 			return self._frames_per_second
 
-		if cfg.inference.audio_backend == "dac":
+		if cfg.audio_backend == "dac":
 			# using the 44KHz model with 24KHz sources has a frame rate of 41Hz
 			if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
 				return 41
@@ -567,7 +567,7 @@ class Inference:
 	amp: bool = False
 
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-	audio_backend: str = "vocos" # encodec, vocos, dac
+	audio_backend: str = "" # encodec, vocos, dac
 
 	# legacy / backwards compat
 	use_vocos: bool = True
@@ -628,6 +628,8 @@ class Config(_Config):
 	sample_rate: int = 24_000
 	variable_sample_rate: bool = False # NOT recommended, as running 24KHz audio directly through the 44KHz DAC model will have detrimental quality loss
 
+	audio_backend: str = "vocos"
+
 	@property
 	def distributed(self):
 		return world_size() > 1
@@ -727,6 +729,9 @@ class Config(_Config):
 		if self.trainer.backend == "local" and self.distributed:
 			self.trainer.ddp = True
 
+		if self.inference.audio_backend != "" and self.audio_backend == "":
+			self.audio_backend = self.inference.audio_backend
+
 # Preserves the old behavior
 class NaiveTokenizer:
 	def get_vocab( self ):
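The shim above keeps old configs working: `audio_backend` now lives on the top-level `Config`, and a value set only on the legacy `Inference` field gets promoted during config setup. A minimal sketch of the resolution order, written as a standalone helper purely for illustration (the helper name is hypothetical; the real logic runs inline as shown above):

```python
def resolve_audio_backend(cfg) -> str:
    # hypothetical helper mirroring the inline shim: prefer the new
    # top-level field, fall back to the legacy inference field,
    # else the declared default
    if cfg.audio_backend:
        return cfg.audio_backend
    if cfg.inference.audio_backend:
        return cfg.inference.audio_backend
    return "vocos"
```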

View File

@@ -63,10 +63,10 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_quant_extension():
-	return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+	return ".dac" if cfg.audio_backend == "dac" else ".enc"
 
 def _get_phone_extension():
-	return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
+	return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"
 
 def _get_quant_path(path):
 	return _replace_file_extension(path, _get_quant_extension())
@@ -876,12 +876,36 @@ def create_dataset_hdf5( skip_existing=True ):
 		if not os.path.isdir(f'{root}/{name}/'):
 			return
+		# tqdm.write(f'{root}/{name}')
 		files = os.listdir(f'{root}/{name}/')
 
 		# grab IDs for every file
 		ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
 
+		"""
+		# rephonemizes if you fuck up and use an old tokenizer...
+		for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
+			key = f'{type}/{speaker_name}/{id}'
+			if key not in hf:
+				continue
+			group = hf[key]
+			if "phonemes" not in entry:
+				continue
+			if "text" not in group:
+				continue
+
+			txt = entry["phonemes"]
+			phn = "".join(txt)
+			phn = cfg.tokenizer.encode(phn)
+			phn = np.array(phn).astype(np.uint8)
+
+			del group["text"]
+			group.create_dataset('text', data=phn, compression='lzf')
+		"""
+
 		for id in tqdm(ids, desc=f"Processing {name}"):
 			try:
 				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
@@ -938,8 +962,10 @@ def create_dataset_hdf5( skip_existing=True ):
 			except Exception as e:
 				tqdm.write(f'Error while processing {id}: {e}')
 
+	"""
 	with open(str(metadata_path), "w", encoding="utf-8") as f:
 		f.write( json.dumps( metadata ) )
+	"""
 
 # training

View File

@@ -170,7 +170,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
 	return model
 
 @cache
-def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
+def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
 	if backend == "dac":
 		return _load_dac_model(device, levels=levels)
 	if backend == "vocos":
@@ -267,7 +267,7 @@ def _replace_file_extension(path, suffix):
 
 @torch.inference_mode()
 def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
-	if cfg.inference.audio_backend == "dac":
+	if cfg.audio_backend == "dac":
 		model = _load_dac_model(device, levels=levels )
 		signal = AudioSignal(wav, sample_rate=sr)
@@ -307,7 +307,7 @@ def encode_from_files(paths, device="cuda"):
 	wav = torch.cat(wavs, dim=-1)
 
-	return encode(wav, sr, "cpu")
+	return encode(wav, sr, device)
 
 def encode_from_file(path, device="cuda"):
 	if isinstance( path, list ):
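One thing to keep in mind when calling these helpers: the return type of `encode`/`encode_from_file` depends on the backend, which is why the `inference.py` change below normalizes with `hasattr(prom, "codes")`. A minimal sketch of that normalization, assuming a DAC-style result object carrying `.codes` versus an EnCodec-style raw code tensor; the input path is a placeholder:

```python
import torch
from vall_e.emb import qnt

prom = qnt.encode_from_file("reference.wav", device="cuda")  # placeholder path

# DAC returns a result object carrying .codes, while the EnCodec-family
# backends return the code tensor directly, so unwrap before slicing
if hasattr(prom, "codes"):
    prom = prom.codes
prom = prom[0][:, :].t().to(torch.int16)  # (levels, frames) -> (frames, levels)
```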

View File

@@ -112,7 +112,18 @@ class TTS():
 			paths = [ Path(p) for p in paths.split(";") ]
 
 			# merge inputs
-			res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
+			proms = []
+			for path in paths:
+				prom = qnt.encode_from_file(path)
+				if hasattr( prom, "codes" ):
+					prom = prom.codes
+				prom = prom[0][:, :].t().to(torch.int16)
+				proms.append( prom )
+
+			res = torch.cat(proms)
 
 			if trim_length:
 				res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )

View File

@@ -319,7 +319,7 @@ class AR_NAR(Base):
 def example_usage():
 	#cfg.trainer.backend = "local"
 	cfg.hyperparameters.gradient_accumulation_steps = 1
-	if cfg.inference.audio_backend == "dac":
+	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_000
 
 	from functools import partial
@@ -340,7 +340,7 @@ def example_usage():
 		return torch.tensor( cfg.tokenizer.encode(content) )
 
 	def _load_quants(path) -> Tensor:
-		if cfg.inference.audio_backend == "dac":
+		if cfg.audio_backend == "dac":
 			qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
 			return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
 		return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
@@ -456,13 +456,13 @@ def example_usage():
 	@torch.inference_mode()
 	def sample( name, steps=1000 ):
-		if cfg.inference.audio_backend == "dac" and name == "init":
+		if cfg.audio_backend == "dac" and name == "init":
 			return
 
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
 
-		if cfg.inference.audio_backend != "dac":
+		if cfg.audio_backend != "dac":
 			for i, o in enumerate(resps_list):
 				_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)