validated that the SpeechX tasks cse and nse work; added a method to test each task by invoking `python3 -m vall_e.data --action=tasks --tasks='sr,se,cse,nse'`
parent 6ca347e1e1
commit f7f6d3bf6d
README.md | 26
@@ -63,7 +63,7 @@ If you're wanting to increase the `prom_levels` for a given model, or increase t

 4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.

-If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --create-hdf5 yaml='./data/config.yaml'`
+If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`

 5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
@@ -81,32 +81,18 @@ Two dataset formats are supported:

 ## Export

-Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run:
+Both trained models *can* be exported, but is only required if loading them on systems without DeepSpeed for inferencing (Windows systems). To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.

-```
-python -m vall_e.export yaml=./data/config.yaml
-```

-This will export the latest checkpoints.
+This will export the latest checkpoints under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth` to be loaded on any system with PyTorch.

 ## Synthesis

-To synthesize speech, invoke either (if exported the models):
-
-```
-python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt
-```
-
-or:
-
-```
-python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>
-```
+To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`

 Some additional flags you can pass are:
 * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
-* `--ar-temp`: sampling temperature to use for the AR pass.
-* `--nar-temp`: sampling temperature to use for the NAR pass.
+* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
+* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output.
 * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)

 ## To-Do
@@ -72,7 +72,7 @@ evaluation:
  steps: 300
  ar_temperature: 1.0
- nar_temperature: 1.0
+ nar_temperature: 0.2

trainer:
  iterations: 1_000_000
@@ -446,6 +446,18 @@ class Config(_Config):
		tmp = Config.from_yaml( config_path )
		self.__dict__.update(tmp.__dict__)

+	def load_hdf5( self, write=False ):
+		if hasattr(self, 'hdf5'):
+			self.hdf5.close()
+
+		if self.distributed:
+			self.dataset.hdf5_flag = "r"
+		try:
+			self.hdf5 = h5py.File(f'{self.cfg_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
+		except Exception as e:
+			print("Error while opening HDF5 file:", f'{self.cfg_path}/{self.dataset.hdf5_name}', str(e))
+			self.dataset.use_hdf5 = False
+
	def format( self ):
		self.dataset = Dataset(**self.dataset)
		self.models = Models(**self.models)
@@ -466,13 +478,7 @@ try:

	# cached_property stopped working...
	if cfg.dataset.use_hdf5:
-		if cfg.distributed:
-			cfg.dataset.hdf5_flag = "r"
-		try:
-			cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
-		except Exception as e:
-			print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
-			cfg.dataset.use_hdf5 = False
+		cfg.load_hdf5()

	if not cfg.dataset.use_hdf5:
		cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
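For orientation, a minimal sketch of how the new `Config.load_hdf5` helper is exercised by the rest of this commit; it assumes `cfg` is the module-level `Config` instance imported from `vall_e.config` (the import path is an assumption, the call sites mirror this diff).

```python
# Hedged sketch, not part of the commit itself.
from vall_e.config import cfg  # assumed location of the global Config instance

# Read path (vall_e/config.py, at import time): open the HDF5 dataset for training.
if cfg.dataset.use_hdf5:
    cfg.load_hdf5()            # read-only ("r") under distributed, else the configured flag

# Write path (vall_e/data.py, create_dataset_hdf5): reopen the same file for writing.
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)      # append mode ("a") so groups/datasets can be created
```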
vall_e/data.py | 170
@@ -300,6 +300,11 @@ class Dataset(_Dataset):
		task = random.choice(self.tasks)

+		# ensure a speaker has at least four utterances
+		# default to tts if not
+		if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
+			task = "tts"
+
		noise_scale = 0.125
		# text-to-speech
		if task == "tts":
@@ -349,7 +354,7 @@ class Dataset(_Dataset):
		# clean speech editing
		elif task == "cse" or task == "nse":
			choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
-			sampled = random.choice([*choices], 4)
+			sampled = random.sample([*choices], 4)

			if cfg.dataset.use_hdf5:
				texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
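As an aside for readers skimming the change above: `random.choice` only accepts a single sequence and returns one element, so the old call could never have worked, while `random.sample` draws the requested number of distinct items. A tiny, self-contained illustration (standard library only, values made up):

```python
import random

choices = {"a.qnt.pt", "b.qnt.pt", "c.qnt.pt", "d.qnt.pt", "e.qnt.pt"}

# random.choice([*choices], 4)          # TypeError: choice() takes no count argument
sampled = random.sample([*choices], 4)  # four distinct utterance paths, as the new code intends
print(sampled)
```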
@@ -359,8 +364,8 @@ class Dataset(_Dataset):
				qnts = [ _load_quants(path) for path in sampled ]

			# remove <s></s>
-			for text in texts:
-				text = text[1:-1]
+			for i in range(len(texts)):
+				texts[i] = texts[i][1:-1]

			pre_text, mid_text, post_text, edit_text = texts
			pre_prom, mid_prom, post_prom, edit_prom = qnts
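The motivation for that rewrite, made concrete: rebinding the loop variable does not write back into the list, so the old loop silently left the `<s>`/`</s>` tokens in place. A minimal illustration with made-up token ids (`1` = `<s>`, `2` = `</s>`):

```python
texts = [[1, 7, 8, 2], [1, 9, 2]]

for text in texts:
    text = text[1:-1]              # rebinds the local name only
print(texts)                       # [[1, 7, 8, 2], [1, 9, 2]] -- unchanged

for i in range(len(texts)):
    texts[i] = texts[i][1:-1]      # writes back into the list, as the new code does
print(texts)                       # [[7, 8], [9]]
```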
@@ -376,11 +381,11 @@ class Dataset(_Dataset):

			# create new text
			text = torch.cat(
-				[ 1 ] + # <s>
-				([ pre_text ] if pre_text is not None else []) +
-				[ edit_text ] +
-				([ post_post ] if post_post is not None else []) +
-				[ 2 ] # </s>
+				[ torch.Tensor( [ 1 ] ).to(dtype=self.text_dtype) ] + # <s>
+				([ pre_text, torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype) ] if pre_text is not None else []) + # pre_text + space'
+				[ edit_text ] + # 'edit text'
+				([ torch.Tensor( [ 3 ] ).to(dtype=self.text_dtype), post_text ] if post_text is not None else []) + # 'space' + edit_text
+				[ torch.Tensor( [ 2 ] ).to(dtype=self.text_dtype) ] # </s>
			)

			if task == "nse":
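To make the stitching above concrete, here is a hypothetical worked example of the new concatenation run outside the class. The sentinel values follow the diff (`1` for `<s>`, `2` for `</s>`, `3` for the separator); the phoneme ids and the `int16` text dtype are assumptions for illustration only.

```python
import torch

text_dtype = torch.int16                                   # assumption; mirrors int16 casts elsewhere in data.py
pre_text  = torch.tensor([10, 11], dtype=text_dtype)       # made-up phoneme ids
edit_text = torch.tensor([20, 21, 22], dtype=text_dtype)
post_text = torch.tensor([30], dtype=text_dtype)

text = torch.cat(
    [ torch.tensor([1], dtype=text_dtype) ] +                                               # <s>
    ([ pre_text, torch.tensor([3], dtype=text_dtype) ] if pre_text is not None else []) +   # pre text + separator
    [ edit_text ] +                                                                         # edited text
    ([ torch.tensor([3], dtype=text_dtype), post_text ] if post_text is not None else []) + # separator + post text
    [ torch.tensor([2], dtype=text_dtype) ]                                                 # </s>
)
print(text)  # tensor([ 1, 10, 11,  3, 20, 21, 22,  3, 30,  2], dtype=torch.int16)
```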
@@ -397,7 +402,7 @@ class Dataset(_Dataset):
				# extend the noise to fill the target audio
				n = repeat_extend_audio(noise, proms.shape[0])
				# merge the noise over the utterance
-				return merge_audio(proms, noise, scale=[1, noise_scale], device="cpu")
+				return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")

			# apply noise to all pieces
			pre_prom = noise_proms( pre_prom )
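For readers unfamiliar with the helper used above, a hedged sketch of what `repeat_extend_audio` plausibly does, inferred only from its call site here (tile a short noise clip along the time axis until it covers the target length); the real implementation in `vall_e` may differ.

```python
import torch

def repeat_extend_audio_sketch(qnt: torch.Tensor, target_len: int) -> torch.Tensor:
    # qnt: quantized audio codes shaped (frames, levels); target_len: number of frames to cover
    reps = -(-target_len // qnt.shape[0])               # ceiling division
    tiled = qnt.repeat(reps, *[1] * (qnt.dim() - 1))    # repeat along the time axis only
    return tiled[:target_len]
```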
@@ -649,14 +654,17 @@ def create_train_val_dataloader():
	return train_dl, subtrain_dl, val_dl

-# parse yaml to create an hdf5 tile
+# parse yaml to create an hdf5 file
def create_dataset_hdf5():
+	cfg.dataset.use_hdf5 = True
+	cfg.load_hdf5(write=True)
+
	symmap = get_phone_symmap()

	root = cfg.cfg_path
	hf = cfg.hdf5

-	def add( dir, type="training" ):
+	def add( dir, type="training", audios=True, texts=True ):
		dir = "./" + str(dir)
		name = dir.replace(root, "")
@@ -670,7 +678,10 @@ def create_dataset_hdf5():
		# grab IDs for every file
		ids = { ".".join(file.split(".")[:-2]) for file in files }
		for id in tqdm(ids, desc=f"Processing {name}"):
-			if not os.path.exists(f'{root}/{name}/{id}.qnt.pt') or not os.path.exists(f'{root}/{name}/{id}.phn.txt'):
+			audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
+			text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+
+			if not audio_exists or not text_exists:
				continue

			key = f'{type}/{name}/{id}'
@@ -681,27 +692,29 @@ def create_dataset_hdf5():
			group = hf.create_group(key)

			# audio
-			qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
-			group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+			if audios:
+				qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+				group.create_dataset('audio', data=qnt.numpy(), compression='lzf')

			# text
-			with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
-				content = f.read()
-				split = content.split(" ")
-				phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
-				for s in set(phones):
-					if s not in symmap:
-						symmap[s] = len(symmap.keys())
-				phn = [ symmap[s] for s in phones ]
+			if texts:
+				with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
+					content = f.read()
+					split = content.split(" ")
+					phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
+					for s in set(phones):
+						if s not in symmap:
+							symmap[s] = len(symmap.keys())
+					phn = [ symmap[s] for s in phones ]

-			group.create_dataset('text', data=phn, compression='lzf', chunks=True)
+				group.create_dataset('text', data=phn, compression='lzf', chunks=True)

-			# metadata
-			group.attrs['id'] = id
-			group.attrs['type'] = type
-			group.attrs['speaker'] = name
-			group.attrs['duration'] = qnt.shape[0] / 75
-			group.attrs['phonemes'] = len(phn)
+			# metadata
+			group.attrs['id'] = id
+			group.attrs['type'] = type
+			group.attrs['speaker'] = name
+			group.attrs['duration'] = qnt.shape[0] / 75
+			group.attrs['phonemes'] = len(phn)

	# training
	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
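For context on what the resulting file looks like, a hedged sketch of reading one utterance back out of the generated HDF5 dataset. The group layout (`{type}/{speaker}/{id}` with `audio` and `text` datasets plus the attributes written above) comes from this diff; the file name and the entry key are placeholders.

```python
import h5py
import torch

with h5py.File('./data/data.h5', 'r') as hf:              # hypothetical file name
    group = hf['training/some_speaker/some_utterance']    # hypothetical entry key
    qnt  = torch.from_numpy(group['audio'][:])             # (frames, quant levels) codes
    text = torch.from_numpy(group['text'][:])              # phoneme symbol ids, including <s>/</s>
    print(group.attrs['speaker'], group.attrs['duration'], group.attrs['phonemes'])
```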
@@ -713,10 +726,13 @@ def create_dataset_hdf5():
	# noise
	for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
-		add( data_dir, type="noise" )
+		add( data_dir, type="noise", texts=False )

	# write symmap
-	hf.create_dataset('symmap', data=json.dumps(symmap))
+	try:
+		hf.create_dataset('symmap', data=json.dumps(symmap))
+	except Exception as e:
+		pass

	hf.close()
@ -724,14 +740,15 @@ if __name__ == "__main__":
|
|||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--task", type=str)
|
||||
parser.add_argument("--action", type=str)
|
||||
parser.add_argument("--tasks", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
task = args.task
|
||||
task = args.action
|
||||
|
||||
if args.task == "hdf5":
|
||||
if args.action == "hdf5":
|
||||
create_dataset_hdf5()
|
||||
elif args.task == "sample":
|
||||
elif args.action == "sample":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
samples = {
|
||||
|
@ -745,77 +762,18 @@ if __name__ == "__main__":
|
|||
del v[i]['proms']
|
||||
del v[i]['resps']
|
||||
print(f'{k}:', v)
|
||||
"""
|
||||
elif args.task == "tasks":
|
||||
elif args.action == "tasks":
|
||||
index = 0
|
||||
task = "ns"
|
||||
cfg.dataset.tasks_list = args.tasks.split(",")
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
batch = next(iter(train_dl))
|
||||
|
||||
train_dataset, val_dataset = create_datasets()
|
||||
train_dataset.task_symmap = get_task_symmap()
|
||||
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
|
||||
if task not in cfg.dataset.tasks_list:
|
||||
continue
|
||||
|
||||
if cfg.dataset.sample_type == "speaker":
|
||||
spkr_name = train_dataset.spkrs[index]
|
||||
spkr_id = train_dataset.spkr_symmap[spkr_name]
|
||||
path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
|
||||
else:
|
||||
path = train_dataset.paths[index]
|
||||
spkr_name = cfg.get_spkr(path)
|
||||
spkr_id = train_dataset.spkr_symmap[spkr_name]
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
|
||||
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
|
||||
else:
|
||||
text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
|
||||
resps = _load_quants(path)
|
||||
|
||||
noise = None
|
||||
if task == "ns" or task == "sr":
|
||||
# sample random noise
|
||||
noise = train_dataset.sample_noise()
|
||||
|
||||
decode_to_file( noise, "./.noise.wav", device="cpu" )
|
||||
|
||||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio(resps, noise, scale=[1, 0.125])
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
# prepend the task token
|
||||
proms = torch.cat( [train_dataset.get_task_token(task), proms] )
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
|
||||
# target speech extraction
|
||||
elif task == "tse":
|
||||
# sample a random, clean, utterance for the target speaker
|
||||
clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# sample a random, clean utterance from a different speaker
|
||||
other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
|
||||
# overlay the random speaker over the target audio
|
||||
|
||||
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||
if other_proms.shape[0] == smallest_size:
|
||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
|
||||
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
||||
else:
|
||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
|
||||
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||
|
||||
# stitch together the promps
|
||||
proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
|
||||
|
||||
decode_to_file( proms, "./.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, "./.resps.wav", device="cpu" )
|
||||
|
||||
if noise is not None:
|
||||
decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
|
||||
"""
|
||||
print(text, task)
|
||||
decode_to_file( proms, f"./.{task}.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, f"./.{task}.resps.wav", device="cpu" )
|
||||
break
|
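Finally, a short usage sketch tying the new `tasks` action back to the commit message. It simply mirrors the `__main__` block above from Python instead of the command line; the import locations of `cfg` and the dataloader helper are assumptions.

```python
# Equivalent of: python3 -m vall_e.data --action=tasks --tasks='sr,se,cse,nse'
from vall_e.config import cfg            # assumed location of the global config instance
from vall_e import data

cfg.dataset.tasks_list = "sr,se,cse,nse".split(",")
train_dl, subtrain_dl, val_dl = data.create_train_val_dataloader()
batch = next(iter(train_dl))
# per the loop above, each matching sample is decoded to ./.{task}.proms.wav and ./.{task}.resps.wav
```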