validated that SpeechX tasks cse and nse works, added a method to test each task by invoking python3 -m vall_e.data --action=tasks --tasks='sr,se,cse,nse'

This commit is contained in:
mrq 2023-08-19 09:50:07 -05:00
parent 6ca347e1e1
commit f7f6d3bf6d
4 changed files with 84 additions and 134 deletions

View File

@ -63,7 +63,7 @@ If you're wanting to increase the `prom_levels` for a given model, or increase t
4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`. 4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --create-hdf5 yaml='./data/config.yaml'` If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`
5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml` 5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
@ -81,32 +81,18 @@ Two dataset formats are supported:
## Export ## Export
Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run: Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
``` This will export the latest checkpoints under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth` to be loaded on any system with PyTorch.
python -m vall_e.export yaml=./data/config.yaml
```
This will export the latest checkpoints.
## Synthesis ## Synthesis
To synthesize speech, invoke either (if exported the models): To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
```
python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt
```
or:
```
python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>
```
Some additional flags you can pass are: Some additional flags you can pass are:
* `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps. * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
* `--ar-temp`: sampling temperature to use for the AR pass. * `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
* `--nar-temp`: sampling temperature to use for the NAR pass. * `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output.
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`) * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
## To-Do ## To-Do

View File

@ -72,7 +72,7 @@ evaluation:
steps: 300 steps: 300
ar_temperature: 1.0 ar_temperature: 1.0
nar_temperature: 1.0 nar_temperature: 0.2
trainer: trainer:
iterations: 1_000_000 iterations: 1_000_000

View File

@ -446,6 +446,18 @@ class Config(_Config):
tmp = Config.from_yaml( config_path ) tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__) self.__dict__.update(tmp.__dict__)
def load_hdf5( self, write=False ):
if hasattr(self, 'hdf5'):
self.hdf5.close()
if self.distributed:
self.dataset.hdf5_flag = "r"
try:
self.hdf5 = h5py.File(f'{self.cfg_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{self.cfg_path}/{self.dataset.hdf5_name}', str(e))
self.dataset.use_hdf5 = False
def format( self ): def format( self ):
self.dataset = Dataset(**self.dataset) self.dataset = Dataset(**self.dataset)
self.models = Models(**self.models) self.models = Models(**self.models)
@ -466,13 +478,7 @@ try:
# cached_property stopped working... # cached_property stopped working...
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
if cfg.distributed: cfg.load_hdf5()
cfg.dataset.hdf5_flag = "r"
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
if not cfg.dataset.use_hdf5: if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ] cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]

View File

@ -300,6 +300,11 @@ class Dataset(_Dataset):
task = random.choice(self.tasks) task = random.choice(self.tasks)
# ensure a speaker has at least four utterances
# default to tts if not
if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
task = "tts"
noise_scale = 0.125 noise_scale = 0.125
# text-to-speech # text-to-speech
if task == "tts": if task == "tts":
@ -349,7 +354,7 @@ class Dataset(_Dataset):
# clean speech editing # clean speech editing
elif task == "cse" or task == "nse": elif task == "cse" or task == "nse":
choices = set(self.paths_by_spkr_name[spkr_name]) - {path} choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
sampled = random.choice([*choices], 4) sampled = random.sample([*choices], 4)
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ] texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
@ -359,8 +364,8 @@ class Dataset(_Dataset):
qnts = [ _load_quants(path) for path in sampled ] qnts = [ _load_quants(path) for path in sampled ]
# remove <s></s> # remove <s></s>
for text in texts: for i in range(len(texts)):
text = text[1:-1] texts[i] = texts[i][1:-1]
pre_text, mid_text, post_text, edit_text = texts pre_text, mid_text, post_text, edit_text = texts
pre_prom, mid_prom, post_prom, edit_prom = qnts pre_prom, mid_prom, post_prom, edit_prom = qnts
@ -376,11 +381,11 @@ class Dataset(_Dataset):
# create new text # create new text
text = torch.cat( text = torch.cat(
[ 1 ] + # <s> [ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
([ pre_text ] if pre_text is not None else []) + ([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
[ edit_text ] + [ edit_text ] + # 'edit text'
([ post_post ] if post_post is not None else []) + ([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
[ 2 ] # </s> [ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
) )
if task == "nse": if task == "nse":
@ -397,7 +402,7 @@ class Dataset(_Dataset):
# extend the noise to fill the target audio # extend the noise to fill the target audio
n = repeat_extend_audio(noise, proms.shape[0]) n = repeat_extend_audio(noise, proms.shape[0])
# merge the noise over the utterance # merge the noise over the utterance
return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu") return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")
# apply noise to all pieces # apply noise to all pieces
pre_prom = noise_proms( pre_prom ) pre_prom = noise_proms( pre_prom )
@ -649,14 +654,17 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl return train_dl, subtrain_dl, val_dl
# parse yaml to create an hdf5 tile # parse yaml to create an hdf5 file
def create_dataset_hdf5(): def create_dataset_hdf5():
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
symmap = get_phone_symmap() symmap = get_phone_symmap()
root = cfg.cfg_path root = cfg.cfg_path
hf = cfg.hdf5 hf = cfg.hdf5
def add( dir, type="training" ): def add( dir, type="training", audios=True, texts=True ):
dir = "./" + str(dir) dir = "./" + str(dir)
name = dir.replace(root, "") name = dir.replace(root, "")
@ -670,7 +678,10 @@ def create_dataset_hdf5():
# grab IDs for every file # grab IDs for every file
ids = { ".".join(file.split(".")[:-2]) for file in files } ids = { ".".join(file.split(".")[:-2]) for file in files }
for id in tqdm(ids, desc=f"Processing {name}"): for id in tqdm(ids, desc=f"Processing {name}"):
if not os.path.exists(f'{root}/{name}/{id}.qnt.pt') or not os.path.exists(f'{root}/{name}/{id}.phn.txt'): audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
if not audio_exists or not text_exists:
continue continue
key = f'{type}/{name}/{id}' key = f'{type}/{name}/{id}'
@ -681,27 +692,29 @@ def create_dataset_hdf5():
group = hf.create_group(key) group = hf.create_group(key)
# audio # audio
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t() if audios:
group.create_dataset('audio', data=qnt.numpy(), compression='lzf') qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
# text # text
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f: if texts:
content = f.read() with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
split = content.split(" ") content = f.read()
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"] split = content.split(" ")
for s in set(phones): phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
if s not in symmap: for s in set(phones):
symmap[s] = len(symmap.keys()) if s not in symmap:
phn = [ symmap[s] for s in phones ] symmap[s] = len(symmap.keys())
phn = [ symmap[s] for s in phones ]
group.create_dataset('text', data=phn, compression='lzf', chunks=True) group.create_dataset('text', data=phn, compression='lzf', chunks=True)
# metadata # metadata
group.attrs['id'] = id group.attrs['id'] = id
group.attrs['type'] = type group.attrs['type'] = type
group.attrs['speaker'] = name group.attrs['speaker'] = name
group.attrs['duration'] = qnt.shape[0] / 75 group.attrs['duration'] = qnt.shape[0] / 75
group.attrs['phonemes'] = len(phn) group.attrs['phonemes'] = len(phn)
# training # training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"): for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@ -713,10 +726,13 @@ def create_dataset_hdf5():
# noise # noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'): for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
add( data_dir, type="noise" ) add( data_dir, type="noise", texts=False )
# write symmap # write symmap
hf.create_dataset('symmap', data=json.dumps(symmap)) try:
hf.create_dataset('symmap', data=json.dumps(symmap))
except Exception as e:
pass
hf.close() hf.close()
@ -724,14 +740,15 @@ if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser("Save trained model to path.") parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--task", type=str) parser.add_argument("--action", type=str)
parser.add_argument("--tasks", type=str)
args = parser.parse_args() args = parser.parse_args()
task = args.task task = args.action
if args.task == "hdf5": if args.action == "hdf5":
create_dataset_hdf5() create_dataset_hdf5()
elif args.task == "sample": elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
samples = { samples = {
@ -745,77 +762,18 @@ if __name__ == "__main__":
del v[i]['proms'] del v[i]['proms']
del v[i]['resps'] del v[i]['resps']
print(f'{k}:', v) print(f'{k}:', v)
""" elif args.action == "tasks":
elif args.task == "tasks":
index = 0 index = 0
task = "ns" cfg.dataset.tasks_list = args.tasks.split(",")
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
batch = next(iter(train_dl))
train_dataset, val_dataset = create_datasets() for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
train_dataset.task_symmap = get_task_symmap() if task not in cfg.dataset.tasks_list:
continue
if cfg.dataset.sample_type == "speaker": print(text, task)
spkr_name = train_dataset.spkrs[index] decode_to_file( proms, f"./.{task}.proms.wav", device="cpu" )
spkr_id = train_dataset.spkr_symmap[spkr_name] decode_to_file( resps, f"./.{task}.resps.wav", device="cpu" )
path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])]) break
else:
path = train_dataset.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = train_dataset.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
resps = _load_quants(path)
noise = None
if task == "ns" or task == "sr":
# sample random noise
noise = train_dataset.sample_noise()
decode_to_file( noise, "./.noise.wav", device="cpu" )
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise, scale=[1, 0.125])
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
# prepend the task token
proms = torch.cat( [train_dataset.get_task_token(task), proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
# target speech extraction
elif task == "tse":
# sample a random, clean, utterance for the target speaker
clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a random, clean utterance from a different speaker
other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
else:
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the promps
proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
decode_to_file( proms, "./.proms.wav", device="cpu" )
decode_to_file( resps, "./.resps.wav", device="cpu" )
if noise is not None:
decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
"""