This commit is contained in:
mrq 2024-05-12 13:02:15 -05:00
parent 4f1593c8db
commit 2437a86efa
5 changed files with 96 additions and 98 deletions

View File

@ -125,7 +125,7 @@ Unfortunately, efforts to train a *good* foundational model seems entirely predi
#### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:
* `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
@ -135,9 +135,24 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
- Its implementation for MoE can also be utilized.
* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
- inferencing cost is about 0.5x, and MoE is not implemented.
- has an inference penalty, and MoE is not implemented.
It's recommended to use `llama` with `xformers`-based attention, as the memory savings are huge in comparison to even `retnet`-backed models.
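For example, a minimal sketch of selecting that combination through the config, assuming fields like `cfg.model.arch_type` and `cfg.model.attention` (the exact keys in your YAML may differ):

```python
from vall_e.config import cfg  # import path assumed

cfg.model.arch_type = "llama"     # transformer backbone: llama, mixtral, retnet, ...
cfg.model.attention = "xformers"  # attention backend for llama-based models
```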
For audio backends:
* [`encodec`](https://github.com/facebookresearch/encodec): a tried-and-tested EnCodec to encode/decode audio.
* [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
- encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
- **Note**: models using `descript-audio-codec` at 24KHz + 6kbps will NOT converge. It is unknown whether 44KHz fares any better.
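The backend itself is just a config field; a minimal sketch, mirroring the processing script later in this diff (the import path is assumed):

```python
from vall_e.config import cfg  # import path assumed

cfg.inference.audio_backend = "vocos"  # one of: "encodec", "vocos", "dac"
```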
`llama`-based models also support different attention backends:
* `math`: torch's SDPA's `math` implementation
* `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) implementation
* `flash`: torch's SDPA's flash attention implementation
* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
* `auto`: determine the best fit from the above
* `sdpa`: integrated `LlamaSdpaAttention` attention module
* `flash_attention_2`: integrated `LlamaFlashAttention2` attention module
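The two integrated variants map onto `transformers`' own `attn_implementation` switch; a rough sketch of that selection at the `transformers` level (VALL-E's own wiring may differ):

```python
from transformers import LlamaConfig, LlamaModel

# "sdpa" and "flash_attention_2" are the integrated variants listed above;
# flash_attention_2 additionally requires the flash-attn package.
config = LlamaConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    attn_implementation="sdpa",
)
model = LlamaModel(config)
```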
## Export

View File

@ -15,7 +15,7 @@ cfg.inference.audio_backend = "encodec"
input_audio = "voices"
input_metadata = "./training/metadata"
output_dataset = f"./training/data-{'2' if cfg.sample_rate else '4'}4KHz-{cfg.inference.audio_backend}"
output_dataset = f"./training/data-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
device = "cuda"
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"

View File

@ -558,7 +558,7 @@ class Inference:
amp: bool = False
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
audio_backend: str = "dac"
audio_backend: str = "vocos" # encodec, vocos, dac
# legacy / backwards compat
use_vocos: bool = True
@ -731,6 +731,7 @@ try:
if cfg.dataset.use_hdf5:
cfg.load_hdf5()
except Exception as e:
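# disable HDF5 on failure so the dataloader can fall back to loose files on disk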
cfg.dataset.use_hdf5 = False
print("Error while parsing config YAML:", e)
pass

View File

@ -66,7 +66,7 @@ def _get_quant_extension():
return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
def _get_phone_extension():
return ".json" if cfg.inference.audio_backend == "dac" else ".phn.txt"
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
def _get_quant_path(path):
return _replace_file_extension(path, _get_quant_extension())
@ -136,7 +136,7 @@ def _get_hdf5_path(path):
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
def _validate(child):
def _validate( child ):
phones = child.attrs['phonemes']
duration = child.attrs['duration']
if type not in _total_durations:
@ -145,7 +145,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{child}") for child in cfg.hdf5[key].keys() if not validate or _validate(child) ] if key in cfg.hdf5 else []
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
if isinstance(path, str):
@ -906,7 +906,6 @@ def create_dataset_hdf5( skip_existing=True ):
if not audio_exists or not text_exists:
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and key in hf:
@ -1014,44 +1013,37 @@ def extract_dataset_hdf5( skip_existing=True ):
root = str(cfg.data_dir)
def add( dir, type="training", audios=True, texts=True ):
name = str(dir)
name = name.replace(root, "data/")
def add( type="training", audios=True, texts=True ):
for group in tqdm( hf[f'{type}/data/'].keys(), desc=f"Processing {type}"):
for name in tqdm( hf[f'{type}/data/{group}'].keys(), desc=f"Processing {group}"):
(cfg.data_dir / group / name).mkdir(parents=True, exist_ok=True)
Path(f'{cfg.relpath}/{name}/').mkdir(parents=True, exist_ok=True)
if f'{type}/{name}' not in hf:
return
ids = [ key for key in hf[f'{type}/{name}'].keys() ]
for id in tqdm(ids, desc=f"Processing {name}"):
for id in tqdm( hf[f'{type}/data/{group}/{name}'].keys(), desc=f"Processing {name}"):
try:
key = f'{type}/{name}/{id}'
key = f'{type}/data/{group}/{name}/{id}'
if key not in hf:
tqdm.write(f'Missing key: {key}')
continue
group = hf[key]
audio_exists = "audio" in group
text_exists = "text" in group
audio_exists = "audio" in hf[key]
text_exists = "text" in hf[key]
if not audio_exists or not text_exists:
tqdm.write(f'Missing audio/text: {key}')
continue
audio_path = Path(f'{cfg.relpath}/{name}/{id}.enc')
text_path = Path(f'{cfg.relpath}/{name}/{id}.json')
audio_path = Path(f'{root}/{group}/{name}/{id}.enc')
text_path = Path(f'{root}/{group}/{name}/{id}.json')
# audio
if audios and audio_exists and not audio_path.exists():
qnt = group["audio"][:, :]
torch.save( qnt, f'{cfg.relpath}/{name}/{id}.enc' )
qnt = hf[key]["audio"][:, :]
torch.save( qnt, audio_path )
# text
if texts and text_exists and not text_path.exists():
tokens = group["text"][:][1:-1]
tokens = hf[key]["text"][:][1:-1]
phones = [ reverse_symmap[f'{token}'] for token in tokens ]
phones = list("".join(phones).replace(" ", " "))
@ -1067,17 +1059,9 @@ def extract_dataset_hdf5( skip_existing=True ):
except Exception as e:
raise e
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
add( data_dir, type="validation" )
# noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
add( data_dir, type="noise", texts=False )
add( type="training" )
add( type="validation" )
add( type="noise", texts=False )
hf.close()
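# assumed HDF5 layout, inferred from the keys above (exact attribute names may differ):
#   /{type}/data/{group}/{name}/{id}/audio -> quantized audio codes
#   /{type}/data/{group}/{name}/{id}/text  -> token ids, BOS/EOS trimmed with [1:-1]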
@ -1091,49 +1075,38 @@ def retokenize_dataset_hdf5( skip_existing=True ):
root = str(cfg.data_dir)
def add( dir, type="training" ):
name = str(dir)
name = name.replace(root, "data/")
def add( type="training" ):
for group in tqdm( hf[f'{type}/data/'].keys(), desc=f"Processing {type}"):
for name in tqdm( hf[f'{type}/data/{group}'].keys(), desc=f"Processing {group}"):
(cfg.data_dir / group / name).mkdir(parents=True, exist_ok=True)
Path(f'{cfg.relpath}/{name}/').mkdir(parents=True, exist_ok=True)
if f'{type}/{name}' not in hf:
return
ids = [ key for key in hf[f'{type}/{name}'].keys() ]
for id in tqdm(ids, desc=f"Processing {name}"):
for id in tqdm( hf[f'{type}/data/{group}/{name}'].keys(), desc=f"Processing {name}"):
try:
key = f'{type}/{name}/{id}'
key = f'{type}/data/{group}/{name}/{id}'
if key not in hf:
tqdm.write(f'Missing key: {key}')
continue
group = hf[key]
if not "text" in group:
if "text" not in hf[key]:
tqdm.write(f'Missing text: {key}')
continue
tokens = group["text"][:][1:-1]
# text
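# decode the stored ids back to phoneme text with the legacy symmap, then re-encode with the current tokenizer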
tokens = hf[key]["text"][:][1:-1]
content = list("".join([ reverse_symmap[f'{token}'] for token in tokens ]).replace(" ", " "))
tokens = cfg.tokenizer.encode("".join(content))
tokens = np.array(tokens).astype(np.uint8)
del group['text']
group.create_dataset('text', data=tokens, compression='lzf')
del hf[key]['text']
hf[key].create_dataset('text', data=tokens, compression='lzf')
except Exception as e:
raise e
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
add( data_dir, type="training" )
# validation
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
add( data_dir, type="validation" )
add( type="training" )
add( type="validation" )
# write symmap
if "symmap" in hf:
@ -1166,6 +1139,15 @@ if __name__ == "__main__":
extract_dataset_hdf5()
if args.action == "retokenize-hdf5":
retokenize_dataset_hdf5()
elif args.action == "list-dataset":
dataset = []
for group in os.listdir(cfg.data_dir):
for name in os.listdir(cfg.data_dir / group):
if len(os.listdir(cfg.data_dir / group / name)) == 0:
continue
dataset.append(f'{group}/{name}')
print(dataset)
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":

View File

@ -147,7 +147,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
if not cfg.variable_sample_rate:
# yes there's a better way, something like f'{cfg.sample_rate//1000}khz'
if cfg.sample_rate == 44_000:
kwargs["model_type"] = "44kz"
kwargs["model_type"] = "44khz"
elif cfg.sample_rate == 24_000:
kwargs["model_type"] = "24khz"
elif cfg.sample_rate == 16_000: