ugh
This commit is contained in:
parent
4f1593c8db
commit
2437a86efa
21
README.md
21
README.md
|
@ -125,7 +125,7 @@ Unfortunately, efforts to train a *good* foundational model seems entirely predi
|
|||
|
||||
#### Backend Architectures
|
||||
|
||||
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
|
||||
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLm architectures:
|
||||
|
||||
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
|
||||
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
|
||||
|
@ -135,9 +135,24 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
|
|||
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
|
||||
- Its implementation for MoE can also be utilized.
|
||||
* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
|
||||
- inferencing cost is about 0.5x, and MoE is not implemented.
|
||||
- has an inference penality, and MoE is not implemented.
|
||||
|
||||
It's recommended to use `llama` with `xformers`-based attention, as the savings are huge in comparison to even `retnet`-backed models.
|
||||
For audio backends:
|
||||
|
||||
* [`encodec`](https://github.com/facebookresearch/encodec): a tried-and-tested EnCodec to encode/decode audio.
|
||||
* [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
|
||||
- encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
|
||||
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
|
||||
- **Note** models using `descript-audio-codec` at 24KHz + 6kbps will NOT converge. Unknown if 44KHz fares any better.
|
||||
|
||||
`llama`-based models also support different attention backends:
|
||||
* `math`: torch's SDPA's `math` implementation
|
||||
* `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) implementation
|
||||
* `flash`: torch's SDPA's flash attention implementation
|
||||
* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
|
||||
* `auto`: determine the best fit from the above
|
||||
* `sdpa`: integrated `LlamaSdpaAttention` attention model
|
||||
* `flash_attention_2`: integrated `LlamaFlashAttetion2` attention model
|
||||
|
||||
## Export
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ cfg.inference.audio_backend = "encodec"
|
|||
|
||||
input_audio = "voices"
|
||||
input_metadata = "./training/metadata"
|
||||
output_dataset = f"./training/data-{'2' if cfg.sample_rate else '4'}4KHz-{cfg.inference.audio_backend}"
|
||||
output_dataset = f"./training/data-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
|
||||
device = "cuda"
|
||||
|
||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
||||
|
|
|
@ -558,7 +558,7 @@ class Inference:
|
|||
amp: bool = False
|
||||
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
audio_backend: str = "dac"
|
||||
audio_backend: str = "vocos" # encodec, vocos, dac
|
||||
|
||||
# legacy / backwards compat
|
||||
use_vocos: bool = True
|
||||
|
@ -731,6 +731,7 @@ try:
|
|||
if cfg.dataset.use_hdf5:
|
||||
cfg.load_hdf5()
|
||||
except Exception as e:
|
||||
cfg.dataset.use_hdf5 = False
|
||||
print("Error while parsing config YAML:", e)
|
||||
pass
|
||||
|
||||
|
|
166
vall_e/data.py
166
vall_e/data.py
|
@ -66,7 +66,7 @@ def _get_quant_extension():
|
|||
return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
|
||||
|
||||
def _get_phone_extension():
|
||||
return ".json" if cfg.inference.audio_backend == "dac" else ".phn.txt"
|
||||
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
|
||||
|
||||
def _get_quant_path(path):
|
||||
return _replace_file_extension(path, _get_quant_extension())
|
||||
|
@ -136,7 +136,7 @@ def _get_hdf5_path(path):
|
|||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||
data_dir = str(data_dir)
|
||||
|
||||
def _validate(child):
|
||||
def _validate( child ):
|
||||
phones = child.attrs['phonemes']
|
||||
duration = child.attrs['duration']
|
||||
if type not in _total_durations:
|
||||
|
@ -145,7 +145,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
|||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||
return [ Path(f"{key}/{child}") for child in cfg.hdf5[key].keys() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
||||
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
||||
|
||||
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
||||
if isinstance(path, str):
|
||||
|
@ -906,7 +906,6 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
if not audio_exists or not text_exists:
|
||||
continue
|
||||
|
||||
|
||||
key = f'{type}/{speaker_name}/{id}'
|
||||
|
||||
if skip_existing and key in hf:
|
||||
|
@ -1014,70 +1013,55 @@ def extract_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
root = str(cfg.data_dir)
|
||||
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
name = str(dir)
|
||||
name = name.replace(root, "data/")
|
||||
def add( type="training", audios=True, texts=True ):
|
||||
for group in tqdm( hf[f'{type}/data/'].keys(), desc=f"Processing {type}"):
|
||||
for name in tqdm( hf[f'{type}/data/{group}'].keys(), desc=f"Processing {group}"):
|
||||
(cfg.data_dir / group / name).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Path(f'{cfg.relpath}/{name}/').mkdir(parents=True, exist_ok=True)
|
||||
for id in tqdm( hf[f'{type}/data/{group}/{name}'].keys(), desc=f"Processing {name}"):
|
||||
try:
|
||||
key = f'{type}/data/{group}/{name}/{id}'
|
||||
|
||||
if f'{type}/{name}' not in hf:
|
||||
return
|
||||
if key not in hf:
|
||||
tqdm.write(f'Missing key: {key}')
|
||||
continue
|
||||
|
||||
ids = [ key for key in hf[f'{type}/{name}'].keys() ]
|
||||
audio_exists = "audio" in hf[key]
|
||||
text_exists = "text" in hf[key]
|
||||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
try:
|
||||
key = f'{type}/{name}/{id}'
|
||||
if not audio_exists or not text_exists:
|
||||
tqdm.write(f'Missing audio/text: {key}')
|
||||
continue
|
||||
|
||||
if key not in hf:
|
||||
tqdm.write(f'Missing key: {key}')
|
||||
continue
|
||||
audio_path = Path(f'{root}/{group}/{name}/{id}.enc')
|
||||
text_path = Path(f'{root}/{group}/{name}/{id}.json')
|
||||
|
||||
group = hf[key]
|
||||
audio_exists = "audio" in group
|
||||
text_exists = "text" in group
|
||||
# audio
|
||||
if audios and audio_exists and not audio_path.exists():
|
||||
qnt = hf[key]["audio"][:, :]
|
||||
torch.save( qnt, audio_path )
|
||||
|
||||
if not audio_exists or not text_exists:
|
||||
tqdm.write(f'Missing audio/text: {key}')
|
||||
continue
|
||||
# text
|
||||
if texts and text_exists and not text_path.exists():
|
||||
tokens = hf[key]["text"][:][1:-1]
|
||||
phones = [ reverse_symmap[f'{token}'] for token in tokens ]
|
||||
phones = list("".join(phones).replace(" ", " "))
|
||||
|
||||
audio_path = Path(f'{cfg.relpath}/{name}/{id}.enc')
|
||||
text_path = Path(f'{cfg.relpath}/{name}/{id}.json')
|
||||
j = {
|
||||
"text": "",
|
||||
"phonemes": phones,
|
||||
"language": "en"
|
||||
}
|
||||
|
||||
# audio
|
||||
if audios and audio_exists and not audio_path.exists():
|
||||
qnt = group["audio"][:, :]
|
||||
torch.save( qnt, f'{cfg.relpath}/{name}/{id}.enc' )
|
||||
with open(text_path, "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( j ) )
|
||||
|
||||
# text
|
||||
if texts and text_exists and not text_path.exists():
|
||||
tokens = group["text"][:][1:-1]
|
||||
phones = [ reverse_symmap[f'{token}'] for token in tokens ]
|
||||
phones = list("".join(phones).replace(" ", " "))
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
j = {
|
||||
"text": "",
|
||||
"phonemes": phones,
|
||||
"language": "en"
|
||||
}
|
||||
|
||||
with open(text_path, "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( j ) )
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||
add( data_dir, type="training" )
|
||||
|
||||
# validation
|
||||
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
|
||||
add( data_dir, type="validation" )
|
||||
|
||||
# noise
|
||||
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
|
||||
add( data_dir, type="noise", texts=False )
|
||||
add( type="training" )
|
||||
add( type="validation" )
|
||||
add( type="noise", texts=False )
|
||||
|
||||
hf.close()
|
||||
|
||||
|
@ -1091,49 +1075,38 @@ def retokenize_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
root = str(cfg.data_dir)
|
||||
|
||||
def add( dir, type="training" ):
|
||||
name = str(dir)
|
||||
name = name.replace(root, "data/")
|
||||
def add( type="training" ):
|
||||
for group in tqdm( hf[f'{type}/data/'].keys(), desc=f"Processing {type}"):
|
||||
for name in tqdm( hf[f'{type}/data/{group}'].keys(), desc=f"Processing {group}"):
|
||||
(cfg.data_dir / group / name).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Path(f'{cfg.relpath}/{name}/').mkdir(parents=True, exist_ok=True)
|
||||
for id in tqdm( hf[f'{type}/data/{group}/{name}'].keys(), desc=f"Processing {name}"):
|
||||
try:
|
||||
key = f'{type}/data/{group}/{name}/{id}'
|
||||
|
||||
if f'{type}/{name}' not in hf:
|
||||
return
|
||||
if key not in hf:
|
||||
tqdm.write(f'Missing key: {key}')
|
||||
continue
|
||||
|
||||
ids = [ key for key in hf[f'{type}/{name}'].keys() ]
|
||||
if "text" not in hf[key]:
|
||||
tqdm.write(f'Missing text: {key}')
|
||||
continue
|
||||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
try:
|
||||
key = f'{type}/{name}/{id}'
|
||||
# text
|
||||
tokens = hf[key]["text"][:][1:-1]
|
||||
content = list("".join([ reverse_symmap[f'{token}'] for token in tokens ]).replace(" ", " "))
|
||||
|
||||
if key not in hf:
|
||||
tqdm.write(f'Missing key: {key}')
|
||||
continue
|
||||
tokens = cfg.tokenizer.encode("".join(content))
|
||||
tokens = np.array(tokens).astype(np.uint8)
|
||||
|
||||
group = hf[key]
|
||||
if not "text" in group:
|
||||
tqdm.write(f'Missing text: {key}')
|
||||
continue
|
||||
del hf[key]['text']
|
||||
hf[key].create_dataset('text', data=tokens, compression='lzf')
|
||||
|
||||
tokens = group["text"][:][1:-1]
|
||||
content = list("".join([ reverse_symmap[f'{token}'] for token in tokens ]).replace(" ", " "))
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
tokens = cfg.tokenizer.encode("".join(content))
|
||||
tokens = np.array(tokens).astype(np.uint8)
|
||||
|
||||
del group['text']
|
||||
group.create_dataset('text', data=tokens, compression='lzf')
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||
add( data_dir, type="training" )
|
||||
|
||||
# validation
|
||||
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
|
||||
add( data_dir, type="validation" )
|
||||
add( type="training" )
|
||||
add( type="validation" )
|
||||
|
||||
# write symmap
|
||||
if "symmap" in hf:
|
||||
|
@ -1166,6 +1139,15 @@ if __name__ == "__main__":
|
|||
extract_dataset_hdf5()
|
||||
if args.action == "retokenize-hdf5":
|
||||
retokenize_dataset_hdf5()
|
||||
elif args.action == "list-dataset":
|
||||
dataset = []
|
||||
for group in os.listdir(cfg.data_dir):
|
||||
for name in os.listdir(cfg.data_dir / group):
|
||||
if len(os.listdir(cfg.data_dir / group / name)) == 0:
|
||||
continue
|
||||
dataset.append(f'{group}/{name}')
|
||||
|
||||
print(dataset)
|
||||
elif args.action == "metadata":
|
||||
create_dataset_metadata()
|
||||
elif args.action == "sample":
|
||||
|
|
|
@ -147,7 +147,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
|||
if not cfg.variable_sample_rate:
|
||||
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
||||
if cfg.sample_rate == 44_000:
|
||||
kwargs["model_type"] = "44kz"
|
||||
kwargs["model_type"] = "44khz"
|
||||
elif cfg.sample_rate == 24_000:
|
||||
kwargs["model_type"] = "24khz"
|
||||
elif cfg.sample_rate == 16_000:
|
||||
|
|
Loading…
Reference in New Issue
Block a user