DAC just doesn't work well enough......
This commit is contained in:
parent e3ef89f5aa
commit ddbacde0d1
@@ -143,7 +143,8 @@ For audio backends:
 * [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
   - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
 * [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
-  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge. Audio encoded through the 44KHz model seems to work.
+  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge in any manner.
+  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps stop improving after a while.

 `llama`-based models also support different attention backends:
 * `math`: torch's SDPA `math` implementation

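For reference, the knobs this commit reshuffles are plain attributes on the global config. A minimal sketch of selecting a backend in user code (the import path comes from the scripts below; the field names mirror this diff, everything else is assumption):

	from vall_e.config import cfg

	# "encodec", "vocos", or "dac"; the new top-level field replaces
	# cfg.inference.audio_backend (kept as a legacy fallback below)
	cfg.audio_backend = "encodec"
	# EnCodec/Vocos operate on 24KHz audio; DAC's workable model is the 44KHz one
	cfg.sample_rate = 24_000
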
@@ -8,8 +8,8 @@ from pathlib import Path
 from vall_e.config import cfg

 # things that could be args
-cfg.sample_rate = 44_000
-cfg.inference.audio_backend = "dac"
+cfg.sample_rate = 24_000
+cfg.inference.audio_backend = "encodec"
 """
 cfg.inference.weight_dtype = "bfloat16"
 cfg.inference.dtype = torch.bfloat16

@@ -1,14 +1,29 @@
 import os
+import json
 import torch
+import numpy as np

 from tqdm.auto import tqdm
 from pathlib import Path

+from vall_e.config import cfg
+
+# things that could be args
+cfg.sample_rate = 24_000
+cfg.inference.audio_backend = "encodec"
+"""
+cfg.inference.weight_dtype = "bfloat16"
+cfg.inference.dtype = torch.bfloat16
+cfg.inference.amp = True
+"""
+
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension

+audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+
 input_dataset = "LibriTTS_R"
-output_dataset = "LibriTTS-Train"
+output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
 device = "cuda"

 txts = []

@@ -32,24 +47,61 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
 				inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}')
 				outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}')

-				if ".original.txt" in filename and not _replace_file_extension(outpath, ".json").exists():
-					txts.append([inpath, outpath])
-				if ".wav" in filename and not _replace_file_extension(outpath, ".dac").exists():
-					wavs.append([inpath, outpath])
+				if ".wav" in filename: # and not _replace_file_extension(outpath, ".dac").exists():
+					txts.append((
+						inpath,
+						outpath
+					))

-for paths in tqdm(txts, desc="Phonemizing..."):
-	text = open(paths[0], "r", encoding="utf-8").read()
-	phones = valle_phonemize(text)
-	data = {
-		"text": text,
-		"phonemes": phones,
-		"language": "english",
-	}
-	open(_replace_file_extension(paths[1], ".json"), 'w', encoding='utf-8').write(json.dumps(data))
-	#phones = valle_phonemize(open(paths[0], "r", encoding="utf-8").read())
-	#open(_replace_file_extension(paths[1], ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
+for paths in tqdm(txts, desc="Processing..."):
+	inpath, outpath = paths
+	try:
+		if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
+			data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
+			qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
+
+			if not isinstance(data["phonemes"], str):
+				data["phonemes"] = "".join(data["phonemes"])
+
-
-for paths in tqdm(wavs, desc="Quantizing..."):
-	qnt = valle_quantize(paths[0], device=device)
-	qnt.save(_replace_file_extension(paths[1], ".dac"))
-	#torch.save(qnt.cpu(), _replace_file_extension(paths[1], ".qnt.pt"))
+			for k, v in data.items():
+				qnt[()]['metadata'][k] = v
+
+			np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
+		else:
+			text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+
+			phones = valle_phonemize(text)
+			qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
+
+			if cfg.inference.audio_backend == "dac":
+				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+					"metadata": {
+						"original_length": qnt.original_length,
+						"sample_rate": qnt.sample_rate,
+
+						"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+						"chunk_length": qnt.chunk_length,
+						"channels": qnt.channels,
+						"padding": qnt.padding,
+						"dac_version": "1.0.0",
+
+						"text": text.strip(),
+						"phonemes": "".join(phones),
+						"language": "en",
+					},
+				})
+			else:
+				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+					"codes": qnt.cpu().numpy().astype(np.uint16),
+					"metadata": {
+						"original_length": qnt.shape[0] / 75.0,
+						"sample_rate": cfg.sample_rate,
+
+						"text": text.strip(),
+						"phonemes": "".join(phones),
+						"language": "en",
+					},
+				})
+	except Exception as e:
+		tqdm.write(f"Failed to process: {paths}: {e}")

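Since the script stores each utterance as a pickled dict through np.save, a saved file can be sanity-checked by loading it back. A small sketch under the same assumptions as the script above ("sample.enc" is a placeholder path; the key layout is the one written by the else-branch):

	import numpy as np

	qnt = np.load("sample.enc", allow_pickle=True)[()]  # unwrap the 0-d object array
	codes = qnt["codes"]      # uint16 quantizer codes
	meta = qnt["metadata"]    # text, phonemes, language, sample_rate, ...
	print(codes.shape, meta["text"])
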
@@ -156,7 +156,7 @@ class Dataset:
 	p_resp_append: float = 1.0

 	sample_type: str = "path" # path | speaker

 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])

 	_frames_per_second: int = 0 # allows setting your own hint

@@ -166,7 +166,7 @@ class Dataset:
 		if self._frames_per_second > 0:
 			return self._frames_per_second

-		if cfg.inference.audio_backend == "dac":
+		if cfg.audio_backend == "dac":
 			# using the 44KHz model with 24KHz sources has a frame rate of 41Hz
 			if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
 				return 41

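The frame rate is what converts stored code counts into durations, which is why the hint above has to be backend-aware. A sketch of the arithmetic, using the 75 frames/second the processing script assumes for EnCodec at 24KHz and the 41Hz figure from the comment in this hunk (`codes` being a loaded codes array):

	frames_per_second = 75            # EnCodec @ 24KHz
	seconds = codes.shape[0] / frames_per_second
	# feeding 24KHz sources to DAC's 44KHz model instead lands at ~41 frames/second,
	# hence the special case when variable_sample_rate is enabled
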
@@ -567,7 +567,7 @@ class Inference:
 	amp: bool = False

 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-	audio_backend: str = "vocos" # encodec, vocos, dac
+	audio_backend: str = "" # encodec, vocos, dac

 	# legacy / backwards compat
 	use_vocos: bool = True

@@ -628,6 +628,8 @@ class Config(_Config):
 	sample_rate: int = 24_000
 	variable_sample_rate: bool = False # NOT recommended, as running 24KHz audio directly through the 44KHz DAC model will have detrimental quality loss
+
+	audio_backend: str = "vocos"

 	@property
 	def distributed(self):
 		return world_size() > 1

@@ -726,6 +728,9 @@ class Config(_Config):

 		if self.trainer.backend == "local" and self.distributed:
 			self.trainer.ddp = True

+		if self.inference.audio_backend != "" and self.audio_backend == "":
+			self.audio_backend = self.inference.audio_backend
+
 # Preserves the old behavior
 class NaiveTokenizer:

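The two lines added above are the whole backwards-compatibility story: old configs that only set the inference-scoped field keep working. A sketch of the intended effect (assuming the fallback runs as part of the config's existing post-load setup, as the surrounding hunk suggests):

	from vall_e.config import cfg

	cfg.inference.audio_backend = "vocos"  # legacy location, e.g. from an old YAML
	# after the setup pass above, the new top-level field is populated:
	assert cfg.audio_backend == "vocos"
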
@@ -63,10 +63,10 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

 def _get_quant_extension():
-	return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+	return ".dac" if cfg.audio_backend == "dac" else ".enc"

 def _get_phone_extension():
-	return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
+	return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"

 def _get_quant_path(path):
 	return _replace_file_extension(path, _get_quant_extension())

@@ -876,12 +876,36 @@ def create_dataset_hdf5( skip_existing=True ):
 		if not os.path.isdir(f'{root}/{name}/'):
 			return
 		# tqdm.write(f'{root}/{name}')

 		files = os.listdir(f'{root}/{name}/')

 		# grab IDs for every file
 		ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

+		"""
+		# rephonemizes if you fuck up and use an old tokenizer...
+		for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
+			key = f'{type}/{speaker_name}/{id}'
+
+			if key not in hf:
+				continue
+
+			group = hf[key]
+
+			if "phonemes" not in entry:
+				continue
+			if "text" not in group:
+				continue
+
+			txt = entry["phonemes"]
+			phn = "".join(txt)
+			phn = cfg.tokenizer.encode(phn)
+			phn = np.array(phn).astype(np.uint8)
+
+			del group["text"]
+			group.create_dataset('text', data=phn, compression='lzf')
+		"""
+
 		for id in tqdm(ids, desc=f"Processing {name}"):
 			try:
 				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True

@@ -938,8 +962,10 @@ def create_dataset_hdf5( skip_existing=True ):
 			except Exception as e:
 				tqdm.write(f'Error while processing {id}: {e}')

+	"""
 	with open(str(metadata_path), "w", encoding="utf-8") as f:
 		f.write( json.dumps( metadata ) )
+	"""

 # training

@@ -170,7 +170,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
 	return model

 @cache
-def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
+def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
 	if backend == "dac":
 		return _load_dac_model(device, levels=levels)
 	if backend == "vocos":

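Note the import-order subtlety here: backend=cfg.audio_backend is a default argument, so it is evaluated once, when the module is first imported. That is presumably why the processing script above sets its cfg fields before importing from vall_e.emb. A sketch of the gotcha in miniature:

	from vall_e.config import cfg
	cfg.audio_backend = "dac"      # must happen first
	from vall_e.emb import qnt     # _load_model's backend default is bound here
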
@@ -267,7 +267,7 @@ def _replace_file_extension(path, suffix):

 @torch.inference_mode()
 def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
-	if cfg.inference.audio_backend == "dac":
+	if cfg.audio_backend == "dac":
 		model = _load_dac_model(device, levels=levels )
 		signal = AudioSignal(wav, sample_rate=sr)

@@ -307,7 +307,7 @@ def encode_from_files(paths, device="cuda"):

 	wav = torch.cat(wavs, dim=-1)

-	return encode(wav, sr, "cpu")
+	return encode(wav, sr, device)

 def encode_from_file(path, device="cuda"):
 	if isinstance( path, list ):

@@ -112,7 +112,18 @@ class TTS():
 		paths = [ Path(p) for p in paths.split(";") ]

 		# merge inputs
-		res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
+
+		proms = []
+
+		for path in paths:
+			prom = qnt.encode_from_file(path)
+			if hasattr( prom, "codes" ):
+				prom = prom.codes
+			prom = prom[0][:, :].t().to(torch.int16)
+
+			proms.append( prom )
+
+		res = torch.cat(proms)

 		if trim_length:
 			res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )

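The rewritten merge loop exists because the two encoder paths return different kinds of result: the DAC path hands back an object carrying its codes (plus metadata), while the EnCodec path returns a bare tensor. The hasattr check is the normalization point; isolated, under the same assumptions as the hunk above:

	prom = qnt.encode_from_file(path)
	if hasattr( prom, "codes" ):   # DAC-style result object
		prom = prom.codes          # unwrap to the raw codes tensor
	prom = prom[0][:, :].t().to(torch.int16)
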
@@ -319,7 +319,7 @@ class AR_NAR(Base):
 def example_usage():
 	#cfg.trainer.backend = "local"
 	cfg.hyperparameters.gradient_accumulation_steps = 1
-	if cfg.inference.audio_backend == "dac":
+	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_000

 	from functools import partial

@@ -340,7 +340,7 @@ def example_usage():
 		return torch.tensor( cfg.tokenizer.encode(content) )

 	def _load_quants(path) -> Tensor:
-		if cfg.inference.audio_backend == "dac":
+		if cfg.audio_backend == "dac":
 			qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
 			return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
 		return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)

@@ -456,13 +456,13 @@ def example_usage():

 	@torch.inference_mode()
 	def sample( name, steps=1000 ):
-		if cfg.inference.audio_backend == "dac" and name == "init":
+		if cfg.audio_backend == "dac" and name == "init":
 			return

 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )

-		if cfg.inference.audio_backend != "dac":
+		if cfg.audio_backend != "dac":
 			for i, o in enumerate(resps_list):
 				_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
Loading…
Reference in New Issue
Block a user