validated that SpeechX tasks cse and nse works, added a method to test each task by invoking python3 -m vall_e.data --action=tasks --tasks='sr,se,cse,nse'

This commit is contained in:
mrq 2023-08-19 09:50:07 -05:00
parent 6ca347e1e1
commit f7f6d3bf6d
4 changed files with 84 additions and 134 deletions

View File

@ -63,7 +63,7 @@ If you're wanting to increase the `prom_levels` for a given model, or increase t
4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --create-hdf5 yaml='./data/config.yaml'`
If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`
5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
@ -81,32 +81,18 @@ Two dataset formats are supported:
## Export
Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run:
Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
```
python -m vall_e.export yaml=./data/config.yaml
```
This will export the latest checkpoints.
This will export the latest checkpoints under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth` to be loaded on any system with PyTorch.
## Synthesis
To synthesize speech, invoke either (if exported the models):
```
python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt
```
or:
```
python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>
```
To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
Some additional flags you can pass are:
* `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
* `--ar-temp`: sampling temperature to use for the AR pass.
* `--nar-temp`: sampling temperature to use for the NAR pass.
* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output.
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
## To-Do

View File

@ -72,7 +72,7 @@ evaluation:
steps: 300
ar_temperature: 1.0
nar_temperature: 1.0
nar_temperature: 0.2
trainer:
iterations: 1_000_000

View File

@ -446,6 +446,18 @@ class Config(_Config):
tmp = Config.from_yaml( config_path )
self.__dict__.update(tmp.__dict__)
def load_hdf5( self, write=False ):
if hasattr(self, 'hdf5'):
self.hdf5.close()
if self.distributed:
self.dataset.hdf5_flag = "r"
try:
self.hdf5 = h5py.File(f'{self.cfg_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{self.cfg_path}/{self.dataset.hdf5_name}', str(e))
self.dataset.use_hdf5 = False
def format( self ):
self.dataset = Dataset(**self.dataset)
self.models = Models(**self.models)
@ -466,13 +478,7 @@ try:
# cached_property stopped working...
if cfg.dataset.use_hdf5:
if cfg.distributed:
cfg.dataset.hdf5_flag = "r"
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False
cfg.load_hdf5()
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]

View File

@ -300,6 +300,11 @@ class Dataset(_Dataset):
task = random.choice(self.tasks)
# ensure a speaker has at least four utterances
# default to tts if not
if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
task = "tts"
noise_scale = 0.125
# text-to-speech
if task == "tts":
@ -349,7 +354,7 @@ class Dataset(_Dataset):
# clean speech editing
elif task == "cse" or task == "nse":
choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
sampled = random.choice([*choices], 4)
sampled = random.sample([*choices], 4)
if cfg.dataset.use_hdf5:
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
@ -359,8 +364,8 @@ class Dataset(_Dataset):
qnts = [ _load_quants(path) for path in sampled ]
# remove <s></s>
for text in texts:
text = text[1:-1]
for i in range(len(texts)):
texts[i] = texts[i][1:-1]
pre_text, mid_text, post_text, edit_text = texts
pre_prom, mid_prom, post_prom, edit_prom = qnts
@ -376,11 +381,11 @@ class Dataset(_Dataset):
# create new text
text = torch.cat(
[ 1 ] + # <s>
([ pre_text ] if pre_text is not None else []) +
[ edit_text ] +
([ post_post ] if post_post is not None else []) +
[ 2 ] # </s>
[ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
[ edit_text ] + # 'edit text'
([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
[ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
)
if task == "nse":
@ -397,7 +402,7 @@ class Dataset(_Dataset):
# extend the noise to fill the target audio
n = repeat_extend_audio(noise, proms.shape[0])
# merge the noise over the utterance
return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu")
return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")
# apply noise to all pieces
pre_prom = noise_proms( pre_prom )
@ -649,14 +654,17 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
# parse yaml to create an hdf5 tile
# parse yaml to create an hdf5 file
def create_dataset_hdf5():
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
symmap = get_phone_symmap()
root = cfg.cfg_path
hf = cfg.hdf5
def add( dir, type="training" ):
def add( dir, type="training", audios=True, texts=True ):
dir = "./" + str(dir)
name = dir.replace(root, "")
@ -670,7 +678,10 @@ def create_dataset_hdf5():
# grab IDs for every file
ids = { ".".join(file.split(".")[:-2]) for file in files }
for id in tqdm(ids, desc=f"Processing {name}"):
if not os.path.exists(f'{root}/{name}/{id}.qnt.pt') or not os.path.exists(f'{root}/{name}/{id}.phn.txt'):
audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
if not audio_exists or not text_exists:
continue
key = f'{type}/{name}/{id}'
@ -681,27 +692,29 @@ def create_dataset_hdf5():
group = hf.create_group(key)
# audio
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
if audios:
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
# text
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
content = f.read()
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
for s in set(phones):
if s not in symmap:
symmap[s] = len(symmap.keys())
phn = [ symmap[s] for s in phones ]
if texts:
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
content = f.read()
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
for s in set(phones):
if s not in symmap:
symmap[s] = len(symmap.keys())
phn = [ symmap[s] for s in phones ]
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
# metadata
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = name
group.attrs['duration'] = qnt.shape[0] / 75
group.attrs['phonemes'] = len(phn)
# metadata
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = name
group.attrs['duration'] = qnt.shape[0] / 75
group.attrs['phonemes'] = len(phn)
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@ -713,10 +726,13 @@ def create_dataset_hdf5():
# noise
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
add( data_dir, type="noise" )
add( data_dir, type="noise", texts=False )
# write symmap
hf.create_dataset('symmap', data=json.dumps(symmap))
try:
hf.create_dataset('symmap', data=json.dumps(symmap))
except Exception as e:
pass
hf.close()
@ -724,14 +740,15 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--task", type=str)
parser.add_argument("--action", type=str)
parser.add_argument("--tasks", type=str)
args = parser.parse_args()
task = args.task
task = args.action
if args.task == "hdf5":
if args.action == "hdf5":
create_dataset_hdf5()
elif args.task == "sample":
elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
samples = {
@ -745,77 +762,18 @@ if __name__ == "__main__":
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)
"""
elif args.task == "tasks":
elif args.action == "tasks":
index = 0
task = "ns"
cfg.dataset.tasks_list = args.tasks.split(",")
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
batch = next(iter(train_dl))
train_dataset, val_dataset = create_datasets()
train_dataset.task_symmap = get_task_symmap()
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
if task not in cfg.dataset.tasks_list:
continue
if cfg.dataset.sample_type == "speaker":
spkr_name = train_dataset.spkrs[index]
spkr_id = train_dataset.spkr_symmap[spkr_name]
path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
else:
path = train_dataset.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = train_dataset.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
resps = _load_quants(path)
noise = None
if task == "ns" or task == "sr":
# sample random noise
noise = train_dataset.sample_noise()
decode_to_file( noise, "./.noise.wav", device="cpu" )
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise, scale=[1, 0.125])
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
# prepend the task token
proms = torch.cat( [train_dataset.get_task_token(task), proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
# target speech extraction
elif task == "tse":
# sample a random, clean, utterance for the target speaker
clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a random, clean utterance from a different speaker
other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
else:
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the promps
proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
decode_to_file( proms, "./.proms.wav", device="cpu" )
decode_to_file( resps, "./.resps.wav", device="cpu" )
if noise is not None:
decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
"""
print(text, task)
decode_to_file( proms, f"./.{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./.{task}.resps.wav", device="cpu" )
break